Skip to content

Commit e882b18

Browse files
authored
OpenGL compatibility via metaclass: TracedPath, ParametricFunction, Brace, VGroup (#1572)
* TracedPath complete * ParametricFunction complete * Brace complete * VGroup and VDict complete * isort fix * type fix * Revert "VGroup and VDict complete" This reverts commit e0cc022. * VGroup complete * fix super * Adapt to the new metaclass approach * remove self.basecls
1 parent e5cc89f commit e882b18

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

manim/mobject/changing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import numpy as np
66

7+
from .._config import config
78
from ..constants import *
89
from ..mobject.types.vectorized_mobject import VGroup, VMobject
910
from ..utils.color import BLUE_B, BLUE_D, BLUE_E, GREY_BROWN, WHITE
1011
from ..utils.rate_functions import smooth
12+
from .opengl_compatibility import ConvertToOpenGL
1113

1214

1315
class AnimatedBoundary(VGroup):
@@ -89,7 +91,7 @@ def full_family_become_partial(self, mob1, mob2, a, b):
8991
return self
9092

9193

92-
class TracedPath(VMobject):
94+
class TracedPath(VMobject, metaclass=ConvertToOpenGL):
9395
"""Traces the path of a point returned by a function call.
9496
9597
Examples
@@ -123,15 +125,18 @@ def __init__(
123125

124126
def update_path(self):
125127
new_point = self.traced_point_func()
126-
if self.has_no_points():
128+
if not self.has_points():
127129
self.start_new_path(new_point)
128130
self.add_line_to(new_point)
129131
else:
130132
# Set the end to be the new point
131-
self.points[-1] = new_point
133+
self.get_points()[-1] = new_point
132134

133135
# Second to last point
134-
nppcc = self.n_points_per_cubic_curve
135-
dist = np.linalg.norm(new_point - self.points[-nppcc])
136+
if config["renderer"] == "opengl":
137+
nppcc = self.n_points_per_curve
138+
else:
139+
nppcc = self.n_points_per_cubic_curve
140+
dist = np.linalg.norm(new_point - self.get_points()[-nppcc])
136141
if dist >= self.min_distance_to_new_point:
137142
self.add_line_to(new_point)

manim/mobject/functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from ..constants import *
1010
from ..mobject.types.vectorized_mobject import VMobject
1111
from ..utils.color import YELLOW
12+
from .opengl_compatibility import ConvertToOpenGL
1213

1314

14-
class ParametricFunction(VMobject):
15+
class ParametricFunction(VMobject, metaclass=ConvertToOpenGL):
1516
"""A parametric curve.
1617
1718
Examples
@@ -65,7 +66,7 @@ def __init__(
6566
self.use_smoothing = use_smoothing
6667
self.t_min, self.t_max, self.t_step = t_range
6768

68-
VMobject.__init__(self, **kwargs)
69+
super().__init__(**kwargs)
6970

7071
def get_function(self):
7172
return self.function
@@ -98,6 +99,8 @@ def generate_points(self):
9899
self.make_smooth()
99100
return self
100101

102+
init_points = generate_points
103+
101104

102105
class FunctionGraph(ParametricFunction):
103106
def __init__(self, function, x_range=None, color=YELLOW, **kwargs):

manim/mobject/svg/brace.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66

77
import numpy as np
88

9+
from manim._config import config
10+
from manim.mobject.opengl_compatibility import ConvertToOpenGL
11+
912
from ...animation.composition import AnimationGroup
1013
from ...animation.fading import FadeIn
1114
from ...animation.growing import GrowFromCenter
1215
from ...constants import *
1316
from ...mobject.geometry import Arc, Line
1417
from ...mobject.svg.svg_path import SVGPathMobject
1518
from ...mobject.svg.tex_mobject import MathTex, Tex
19+
from ...mobject.types.opengl_vectorized_mobject import OpenGLVMobject
1620
from ...mobject.types.vectorized_mobject import VMobject
1721
from ...utils.color import BLACK
1822

@@ -119,14 +123,17 @@ def get_tex(self, *tex, **kwargs):
119123

120124
def get_tip(self):
121125
# Returns the position of the seventh point in the path, which is the tip.
122-
return self.points[28] # = 7*4
126+
if config["renderer"] == "opengl":
127+
return self.get_points()[34]
128+
129+
return self.get_points()[28] # = 7*4
123130

124131
def get_direction(self):
125132
vect = self.get_tip() - self.get_center()
126133
return vect / np.linalg.norm(vect)
127134

128135

129-
class BraceLabel(VMobject):
136+
class BraceLabel(VMobject, metaclass=ConvertToOpenGL):
130137
def __init__(
131138
self,
132139
obj,
@@ -138,10 +145,11 @@ def __init__(
138145
):
139146
self.label_constructor = label_constructor
140147
self.label_scale = label_scale
141-
VMobject.__init__(self, **kwargs)
148+
super().__init__(**kwargs)
149+
142150
self.brace_direction = brace_direction
143151
if isinstance(obj, list):
144-
obj = VMobject(*obj)
152+
obj = self.get_group_class()(*obj)
145153
self.brace = Brace(obj, brace_direction, **kwargs)
146154

147155
if isinstance(text, tuple) or isinstance(text, list):
@@ -159,7 +167,7 @@ def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter):
159167

160168
def shift_brace(self, obj, **kwargs):
161169
if isinstance(obj, list):
162-
obj = VMobject(*obj)
170+
obj = self.get_group_class()(*obj)
163171
self.brace = Brace(obj, self.brace_direction, **kwargs)
164172
self.brace.put_at_tip(self.label)
165173
self.submobjects[0] = self.brace

manim/mobject/types/vectorized_mobject.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,7 +1468,7 @@ def force_direction(self, target_direction):
14681468
return self
14691469

14701470

1471-
class VGroup(VMobject):
1471+
class VGroup(VMobject, metaclass=ConvertToOpenGL):
14721472
"""A group of vectorized mobjects.
14731473
14741474
This can be used to group multiple :class:`~.VMobject` instances together
@@ -1520,7 +1520,7 @@ def construct(self):
15201520
"""
15211521

15221522
def __init__(self, *vmobjects, **kwargs):
1523-
VMobject.__init__(self, **kwargs)
1523+
super().__init__(**kwargs)
15241524
self.add(*vmobjects)
15251525

15261526
def __repr__(self):
@@ -1585,7 +1585,7 @@ def construct(self):
15851585
(gr-circle_red).animate.shift(RIGHT)
15861586
)
15871587
"""
1588-
if not all(isinstance(m, VMobject) for m in vmobjects):
1588+
if not all(isinstance(m, (VMobject, OpenGLVMobject)) for m in vmobjects):
15891589
raise TypeError("All submobjects must be of type VMobject")
15901590
return super().add(*vmobjects)
15911591

@@ -1625,7 +1625,7 @@ def __setitem__(self, key: int, value: Union[VMobject, typing.Sequence[VMobject]
16251625
>>> new_obj = VMobject()
16261626
>>> vgroup[0] = new_obj
16271627
"""
1628-
if not all(isinstance(m, VMobject) for m in value):
1628+
if not all(isinstance(m, (VMobject, OpenGLVMobject)) for m in value):
16291629
raise TypeError("All submobjects must be of type VMobject")
16301630
self.submobjects[key] = value
16311631

0 commit comments

Comments
 (0)