Skip to content

Commit ab28d5b

Browse files
k4pranhydrobeam
andauthored
Add OpenGL support for :class:~.PMobject (#1882)
* Handle PMobject in opengl * Handle thin-out method * Set opengl_mobject as TRIANGLES by default to align with Mesh() * Fix set_color_by_gradient method * Fix set_color_by_gradient for opengl * Temp fix due to set_color differences between opengl and cairo * Use true dot shader * Fix set_colors_by_radial_gradient for cairo and opengl * Use radius to normalise alphas * Remove point shaders * Fix fade_to and minor fixes * Fix fade_to issue * add separate point clouds for opengl * Try compatibility with pmobject * Add compatibility * Add compatibility for Mobjects2D * Cleanup pmobjects * Fix flake issues and rename to OpenGLPMobject * Update manim/mobject/types/dot_cloud.py Co-authored-by: Laith Bahodi <[email protected]> * Update manim/mobject/types/opengl_point_cloud_mobject.py Co-authored-by: Laith Bahodi <[email protected]> * Update manim/mobject/types/opengl_point_cloud_mobject.py Co-authored-by: Laith Bahodi <[email protected]> * Update manim/utils/color.py Co-authored-by: Laith Bahodi <[email protected]> * Update manim/mobject/types/opengl_point_cloud_mobject.py Co-authored-by: Laith Bahodi <[email protected]> * use attr for uniforms * Update manim/mobject/types/opengl_point_cloud_mobject.py Co-authored-by: Laith Bahodi <[email protected]> * Remove redundant methods * Add back reset_points * Update manim/mobject/types/opengl_point_cloud_mobject.py Co-authored-by: Laith Bahodi <[email protected]> Co-authored-by: Laith Bahodi <[email protected]>
1 parent fb68023 commit ab28d5b

File tree

11 files changed

+300
-34
lines changed

11 files changed

+300
-34
lines changed

manim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@
7474
from .mobject.table import *
7575
from .mobject.three_d_utils import *
7676
from .mobject.three_dimensions import *
77+
from .mobject.types.dot_cloud import *
7778
from .mobject.types.image_mobject import *
79+
from .mobject.types.opengl_point_cloud_mobject import *
7880
from .mobject.types.point_cloud_mobject import *
7981
from .mobject.types.vectorized_mobject import *
8082
from .mobject.value_tracker import *

manim/mobject/opengl_compatibility.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .. import config
44
from .opengl_mobject import OpenGLMobject
55
from .opengl_three_dimensions import OpenGLSurface
6+
from .types.opengl_point_cloud_mobject import OpenGLPMobject
67
from .types.opengl_vectorized_mobject import OpenGLVMobject
78

89

@@ -25,6 +26,9 @@ def __new__(mcls, name, bases, namespace):
2526
base_names_to_opengl = {
2627
"Mobject": OpenGLMobject,
2728
"VMobject": OpenGLVMobject,
29+
"PMobject": OpenGLPMobject,
30+
"Mobject1D": OpenGLPMobject,
31+
"Mobject2D": OpenGLPMobject,
2832
"Surface": OpenGLSurface,
2933
}
3034

manim/mobject/opengl_mobject.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
# Positive shadow up to 1 makes a side opposite the light darker
6666
shadow=0.0,
6767
# For shaders
68-
render_primitive=moderngl.TRIANGLE_STRIP,
68+
render_primitive=moderngl.TRIANGLES,
6969
texture_paths=None,
7070
depth_test=False,
7171
# If true, the mobject will not get rotated according to camera position
@@ -134,7 +134,7 @@ def init_colors(self):
134134
self.set_color(self.color, self.opacity)
135135

136136
def init_points(self):
137-
# Typically implemented in subclass, unlpess purposefully left blank
137+
# Typically implemented in subclass, unless purposefully left blank
138138
pass
139139

140140
def set_data(self, data):
@@ -264,6 +264,11 @@ def set_points(self, points):
264264
self.refresh_bounding_box()
265265
return self
266266

267+
def apply_over_attr_arrays(self, func):
268+
for attr in self.get_array_attrs():
269+
setattr(self, attr, func(getattr(self, attr)))
270+
return self
271+
267272
def append_points(self, new_points):
268273
self.points = np.vstack([self.points, new_points])
269274
self.refresh_bounding_box()
@@ -1267,6 +1272,23 @@ def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True):
12671272
mob.data[name] = rgbas.copy()
12681273
return self
12691274

1275+
def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True):
1276+
"""Directly set rgba data from `rgbas` and optionally do the same recursively
1277+
with submobjects. This can be used if the `rgbas` have already been generated
1278+
with the correct shape and simply need to be set.
1279+
1280+
Parameters
1281+
----------
1282+
rgbas
1283+
the rgba to be set as data
1284+
name
1285+
the name of the data attribute to be set
1286+
recurse
1287+
set to true to recursively apply this method to submobjects
1288+
"""
1289+
for mob in self.get_family(recurse):
1290+
mob.data[name] = rgbas.copy()
1291+
12701292
def set_color(self, color, opacity=None, recurse=True):
12711293
self.set_rgba_array(color, opacity, recurse=False)
12721294
# Recurse to submobjects differently from how set_rgba_array

manim/mobject/types/dot_cloud.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
__all__ = ["TrueDot", "DotCloud"]
2+
3+
import numpy as np
4+
5+
from ...constants import ORIGIN, RIGHT, UP
6+
from ...utils.color import YELLOW
7+
from .opengl_point_cloud_mobject import OpenGLPMobject
8+
9+
10+
class DotCloud(OpenGLPMobject):
11+
def __init__(
12+
self, color=YELLOW, stroke_width=2.0, radius=2.0, density=10, **kwargs
13+
):
14+
self.radius = radius
15+
self.epsilon = 1.0 / density
16+
super().__init__(
17+
stroke_width=stroke_width, density=density, color=color, **kwargs
18+
)
19+
20+
def init_points(self):
21+
self.points = np.array(
22+
[
23+
r * (np.cos(theta) * RIGHT + np.sin(theta) * UP)
24+
for r in np.arange(self.epsilon, self.radius, self.epsilon)
25+
# Num is equal to int(stop - start)/ (step + 1) reformulated.
26+
for theta in np.linspace(
27+
0, 2 * np.pi, num=int(2 * np.pi * (r + self.epsilon) / self.epsilon)
28+
)
29+
],
30+
dtype=np.float32,
31+
)
32+
33+
def make_3d(self, gloss=0.5, shadow=0.2):
34+
self.set_gloss(gloss)
35+
self.set_shadow(shadow)
36+
self.apply_depth_test()
37+
return self
38+
39+
40+
class TrueDot(DotCloud):
41+
def __init__(self, center=ORIGIN, stroke_width=2.0, **kwargs):
42+
self.radius = stroke_width
43+
super().__init__(points=[center], stroke_width=stroke_width, **kwargs)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]
2+
3+
import moderngl
4+
import numpy as np
5+
6+
from ...constants import *
7+
from ...mobject.opengl_mobject import OpenGLMobject
8+
from ...utils.bezier import interpolate
9+
from ...utils.color import BLACK, WHITE, YELLOW, color_gradient, color_to_rgba
10+
from ...utils.config_ops import _Uniforms
11+
from ...utils.iterables import resize_with_interpolation
12+
13+
14+
class OpenGLPMobject(OpenGLMobject):
15+
shader_folder = "true_dot"
16+
# Scale for consistency with cairo units
17+
OPENGL_POINT_RADIUS_SCALE_FACTOR = 0.01
18+
shader_dtype = [
19+
("point", np.float32, (3,)),
20+
("color", np.float32, (4,)),
21+
]
22+
23+
point_radius = _Uniforms()
24+
25+
def __init__(
26+
self, stroke_width=2.0, color=YELLOW, render_primitive=moderngl.POINTS, **kwargs
27+
):
28+
self.stroke_width = stroke_width
29+
super().__init__(color=color, render_primitive=render_primitive, **kwargs)
30+
self.point_radius = (
31+
self.stroke_width * OpenGLPMobject.OPENGL_POINT_RADIUS_SCALE_FACTOR
32+
)
33+
34+
def reset_points(self):
35+
self.rgbas = np.zeros((1, 4))
36+
self.points = np.zeros((0, 3))
37+
return self
38+
39+
def get_array_attrs(self):
40+
return ["points", "rgbas"]
41+
42+
def add_points(self, points, rgbas=None, color=None, opacity=None):
43+
"""
44+
points must be a Nx3 numpy array, as must rgbas if it is not None
45+
"""
46+
if rgbas is None and color is None:
47+
color = YELLOW
48+
self.append_points(points)
49+
# rgbas array will have been resized with points
50+
if color is not None:
51+
if opacity is None:
52+
opacity = self.rgbas[-1, 3]
53+
new_rgbas = np.repeat([color_to_rgba(color, opacity)], len(points), axis=0)
54+
elif rgbas is not None:
55+
new_rgbas = rgbas
56+
elif len(rgbas) != len(points):
57+
raise ValueError("points and rgbas must have same shape")
58+
self.rgbas = np.append(self.rgbas, new_rgbas, axis=0)
59+
return self
60+
61+
def thin_out(self, factor=5):
62+
"""
63+
Removes all but every nth point for n = factor
64+
"""
65+
for mob in self.family_members_with_points():
66+
num_points = mob.get_num_points()
67+
68+
def thin_func():
69+
return np.arange(0, num_points, factor)
70+
71+
if len(mob.points) == len(mob.rgbas):
72+
mob.set_rgba_array_direct(mob.rgbas[thin_func()])
73+
mob.set_points(mob.points[thin_func()])
74+
75+
return self
76+
77+
def set_color_by_gradient(self, *colors):
78+
self.rgbas = np.array(
79+
list(map(color_to_rgba, color_gradient(*colors, self.get_num_points())))
80+
)
81+
return self
82+
83+
def set_colors_by_radial_gradient(
84+
self, center=None, radius=1, inner_color=WHITE, outer_color=BLACK
85+
):
86+
start_rgba, end_rgba = list(map(color_to_rgba, [inner_color, outer_color]))
87+
if center is None:
88+
center = self.get_center()
89+
for mob in self.family_members_with_points():
90+
distances = np.abs(self.points - center)
91+
alphas = np.linalg.norm(distances, axis=1) / radius
92+
93+
mob.rgbas = np.array(
94+
np.array([interpolate(start_rgba, end_rgba, alpha) for alpha in alphas])
95+
)
96+
return self
97+
98+
def match_colors(self, pmobject):
99+
self.rgbas[:] = resize_with_interpolation(pmobject.rgbas, self.get_num_points())
100+
return self
101+
102+
def fade_to(self, color, alpha, family=True):
103+
rgbas = interpolate(self.rgbas, color_to_rgba(color), alpha)
104+
for mob in self.submobjects:
105+
mob.fade_to(color, alpha, family)
106+
self.set_rgba_array_direct(rgbas)
107+
return self
108+
109+
def filter_out(self, condition):
110+
for mob in self.family_members_with_points():
111+
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
112+
for key in mob.data:
113+
mob.data[key] = mob.data[key][to_keep]
114+
return self
115+
116+
def sort_points(self, function=lambda p: p[0]):
117+
"""
118+
function is any map from R^3 to R
119+
"""
120+
for mob in self.family_members_with_points():
121+
indices = np.argsort(np.apply_along_axis(function, 1, mob.get_points()))
122+
for key in mob.data:
123+
mob.data[key] = mob.data[key][indices]
124+
return self
125+
126+
def ingest_submobjects(self):
127+
for key in self.data:
128+
self.data[key] = np.vstack([sm.data[key] for sm in self.get_family()])
129+
return self
130+
131+
def point_from_proportion(self, alpha):
132+
index = alpha * (self.get_num_points() - 1)
133+
return self.get_points()[int(index)]
134+
135+
def pointwise_become_partial(self, pmobject, a, b):
136+
lower_index = int(a * pmobject.get_num_points())
137+
upper_index = int(b * pmobject.get_num_points())
138+
for key in self.data:
139+
self.data[key] = pmobject.data[key][lower_index:upper_index]
140+
return self
141+
142+
def get_shader_data(self):
143+
shader_data = np.zeros(len(self.get_points()), dtype=self.shader_dtype)
144+
self.read_data_to_shader(shader_data, "point", "points")
145+
self.read_data_to_shader(shader_data, "color", "rgbas")
146+
return shader_data
147+
148+
149+
class OpenGLPGroup(OpenGLPMobject):
150+
def __init__(self, *pmobs, **kwargs):
151+
if not all([isinstance(m, OpenGLPMobject) for m in pmobs]):
152+
raise Exception("All submobjects must be of type OpenglPMObject")
153+
super().__init__(**kwargs)
154+
self.add(*pmobs)
155+
156+
def fade_to(self, color, alpha, family=True):
157+
if family:
158+
for mob in self.submobjects:
159+
mob.fade_to(color, alpha, family)
160+
161+
162+
class OpenGLPMPoint(OpenGLPMobject):
163+
def __init__(self, location=ORIGIN, stroke_width=4.0, color=BLACK, **kwargs):
164+
self.location = location
165+
super().__init__(stroke_width=stroke_width, **kwargs)
166+
167+
def init_points(self):
168+
self.points = np.array([self.location], dtype=np.float32)

0 commit comments

Comments
 (0)