Skip to content

Commit 783bdde

Browse files
Refactored coordinate_systems.py, fixed bugs, added NumberPlane test (#1092)
* Add function to update a dict recursively with other dicts * Refactor Axes and NumberPlane * Fix NumberLine's rotation * Rotate NumberLine just before adding numbers * More refactoring * Add graphical unit test to NumberPlane * More cleanup and fix range attributes issue(#840) * Remove overlooked print statements and run black * Make graphical unit test of NumberPlane more robust * Shift z_axis to XY-plane's origin and minor refactor in its rotation * Apply review suggestions
1 parent 292e49d commit 783bdde

File tree

5 files changed

+114
-73
lines changed

5 files changed

+114
-73
lines changed

manim/mobject/coordinate_systems.py

Lines changed: 74 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..mobject.number_line import NumberLine
1515
from ..mobject.svg.tex_mobject import MathTex
1616
from ..mobject.types.vectorized_mobject import VGroup
17-
from ..utils.config_ops import merge_dicts_recursively
17+
from ..utils.config_ops import merge_dicts_recursively, update_dict_recursively
1818
from ..utils.simple_functions import binary_search
1919
from ..utils.space_ops import angle_of_vector
2020
from ..utils.color import LIGHT_GREY, WHITE, BLUE_D, BLUE
@@ -28,16 +28,12 @@ class CoordinateSystem:
2828
Abstract class for Axes and NumberPlane
2929
"""
3030

31-
def __init__(self, dim=2):
31+
def __init__(self, x_min=None, x_max=None, y_min=None, y_max=None, dim=2):
3232
self.dimension = dim
33-
if not hasattr(self, "x_min"):
34-
self.x_min = -config["frame_x_radius"]
35-
if not hasattr(self, "x_max"):
36-
self.x_max = config["frame_x_radius"]
37-
if not hasattr(self, "y_min"):
38-
self.y_min = -config["frame_y_radius"]
39-
if not hasattr(self, "y_max"):
40-
self.y_max = config["frame_y_radius"]
33+
self.x_min = -config["frame_x_radius"] if x_min is None else x_min
34+
self.x_max = config["frame_x_radius"] if x_max is None else x_max
35+
self.y_min = -config["frame_y_radius"] if y_min is None else y_min
36+
self.y_max = config["frame_y_radius"] if y_max is None else y_max
4137

4238
def coords_to_point(self, *coords):
4339
raise NotImplementedError()
@@ -132,44 +128,56 @@ def input_to_graph_point(self, x, graph):
132128
class Axes(VGroup, CoordinateSystem):
133129
def __init__(
134130
self,
131+
x_min=None,
132+
x_max=None,
133+
y_min=None,
134+
y_max=None,
135135
axis_config=None,
136136
x_axis_config=None,
137137
y_axis_config=None,
138138
center_point=ORIGIN,
139139
**kwargs
140140
):
141-
if axis_config is None:
142-
axis_config = {
143-
"color": LIGHT_GREY,
144-
"include_tip": True,
145-
"exclude_zero_from_default_numbers": True,
146-
}
147-
if y_axis_config is None:
148-
y_axis_config = {"label_direction": LEFT, "rotation": 90 * DEGREES}
149-
self.axis_config = axis_config
150-
if x_axis_config is None:
151-
x_axis_config = {}
152-
self.x_axis_config = x_axis_config
153-
self.y_axis_config = y_axis_config
154-
self.center_point = center_point
155-
CoordinateSystem.__init__(self)
141+
CoordinateSystem.__init__(
142+
self, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max
143+
)
156144
VGroup.__init__(self, **kwargs)
157-
self.x_axis = self.create_axis(self.x_min, self.x_max, self.x_axis_config)
158-
self.y_axis = self.create_axis(self.y_min, self.y_max, self.y_axis_config)
145+
146+
self.axis_config = {
147+
"color": LIGHT_GREY,
148+
"include_tip": True,
149+
"exclude_zero_from_default_numbers": True,
150+
}
151+
self.x_axis_config = {"x_min": self.x_min, "x_max": self.x_max}
152+
self.y_axis_config = {
153+
"x_min": self.y_min,
154+
"x_max": self.y_max,
155+
"label_direction": LEFT,
156+
"rotation": 90 * DEGREES,
157+
}
158+
159+
self.update_default_configs(
160+
(self.axis_config, self.x_axis_config, self.y_axis_config),
161+
(axis_config, x_axis_config, y_axis_config),
162+
)
163+
self.center_point = center_point
164+
self.x_axis = self.create_axis(self.x_axis_config)
165+
self.y_axis = self.create_axis(self.y_axis_config)
159166
# Add as a separate group in case various other
160167
# mobjects are added to self, as for example in
161168
# NumberPlane below
162169
self.axes = VGroup(self.x_axis, self.y_axis, dim=self.dim)
163170
self.add(*self.axes)
164171
self.shift(self.center_point)
165172

166-
def create_axis(self, min_val, max_val, axis_config):
167-
new_config = merge_dicts_recursively(
168-
self.axis_config,
169-
{"x_min": min_val, "x_max": max_val},
170-
axis_config,
171-
)
172-
return NumberLine(**new_config)
173+
@staticmethod
174+
def update_default_configs(default_configs, passed_configs):
175+
for default_config, passed_config in zip(default_configs, passed_configs):
176+
if passed_config is not None:
177+
update_dict_recursively(default_config, passed_config)
178+
179+
def create_axis(self, axis_config):
180+
return NumberLine(**merge_dicts_recursively(self.axis_config, axis_config))
173181

174182
def coords_to_point(self, *coords):
175183
origin = self.x_axis.number_to_point(0)
@@ -209,37 +217,33 @@ def add_coordinates(self, x_vals=None, y_vals=None):
209217
class ThreeDAxes(Axes):
210218
def __init__(
211219
self,
212-
z_axis_config=None,
213-
z_min=-3.5,
214-
z_max=3.5,
215220
x_min=-5.5,
216221
x_max=5.5,
217222
y_min=-5.5,
218223
y_max=5.5,
224+
z_min=-3.5,
225+
z_max=3.5,
226+
z_axis_config=None,
219227
z_normal=DOWN,
220228
num_axis_pieces=20,
221229
light_source=9 * DOWN + 7 * LEFT + 10 * OUT,
222230
**kwargs
223231
):
224-
if z_axis_config is None:
225-
z_axis_config = {}
226-
self.z_axis_config = z_axis_config
232+
Axes.__init__(
233+
self, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, **kwargs
234+
)
227235
self.z_min = z_min
228236
self.z_max = z_max
237+
self.z_axis_config = {"x_min": self.z_min, "x_max": self.z_max}
238+
self.update_default_configs((self.z_axis_config,), (z_axis_config,))
229239
self.z_normal = z_normal
230240
self.num_axis_pieces = num_axis_pieces
231241
self.light_source = light_source
232-
self.x_min = x_min
233-
self.x_max = x_max
234-
self.y_min = y_min
235-
self.y_max = y_max
236-
Axes.__init__(self, **kwargs)
237242
self.dimension = 3
238-
z_axis = self.z_axis = self.create_axis(
239-
self.z_min, self.z_max, self.z_axis_config
240-
)
241-
z_axis.rotate(-np.pi / 2, UP, about_point=ORIGIN)
242-
z_axis.rotate(angle_of_vector(self.z_normal), OUT, about_point=ORIGIN)
243+
z_axis = self.z_axis = self.create_axis(self.z_axis_config)
244+
z_axis.shift(self.center_point)
245+
z_axis.rotate_about_zero(-np.pi / 2, UP)
246+
z_axis.rotate_about_zero(angle_of_vector(self.z_normal))
243247
self.axes.add(z_axis)
244248
self.add(z_axis)
245249

@@ -281,28 +285,26 @@ def __init__(
281285
make_smooth_after_applying_functions=True,
282286
**kwargs
283287
):
284-
if axis_config is None:
285-
axis_config = {
286-
"stroke_color": WHITE,
287-
"stroke_width": 2,
288-
"include_ticks": False,
289-
"include_tip": False,
290-
"line_to_number_buff": SMALL_BUFF,
291-
"label_direction": DR,
292-
"number_scale_val": 0.5,
293-
}
294-
self.axis_config = axis_config
295-
if y_axis_config is None:
296-
y_axis_config = {"label_direction": DR}
297-
self.y_axis_config = y_axis_config
298-
299-
if background_line_style is None:
300-
background_line_style = {
301-
"stroke_color": BLUE_D,
302-
"stroke_width": 2,
303-
"stroke_opacity": 1,
304-
}
305-
self.background_line_style = background_line_style
288+
self.axis_config = {
289+
"stroke_color": WHITE,
290+
"stroke_width": 2,
291+
"include_ticks": False,
292+
"include_tip": False,
293+
"line_to_number_buff": SMALL_BUFF,
294+
"label_direction": DR,
295+
"number_scale_val": 0.5,
296+
}
297+
self.y_axis_config = {"label_direction": DR}
298+
self.background_line_style = {
299+
"stroke_color": BLUE_D,
300+
"stroke_width": 2,
301+
"stroke_opacity": 1,
302+
}
303+
304+
self.update_default_configs(
305+
(self.axis_config, self.y_axis_config, self.background_line_style),
306+
(axis_config, y_axis_config, background_line_style),
307+
)
306308

307309
# Defaults to a faded version of line_config
308310
self.faded_line_style = faded_line_style

manim/mobject/number_line.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,16 @@ def __init__(
9494
self.add_tip()
9595
if self.include_ticks:
9696
self.add_tick_marks()
97-
self.rotate(self.rotation)
97+
self.rotate_about_zero(self.rotation)
9898
if self.include_numbers:
9999
self.add_numbers()
100100

101+
def rotate_about_zero(self, angle, axis=OUT, **kwargs):
102+
return self.rotate_about_number(0, angle, axis, **kwargs)
103+
104+
def rotate_about_number(self, number, angle, axis=OUT, **kwargs):
105+
return self.rotate(angle, axis, about_point=self.n2p(number), **kwargs)
106+
101107
def init_leftmost_tick(self):
102108
if self.leftmost_tick is None:
103109
self.leftmost_tick = op.mul(

manim/utils/config_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
__all__ = [
44
"merge_dicts_recursively",
5+
"update_dict_recursively",
56
"DictAsObject",
67
]
78

@@ -29,6 +30,11 @@ def merge_dicts_recursively(*dicts):
2930
return result
3031

3132

33+
def update_dict_recursively(current_dict, *others):
34+
updated_dict = merge_dicts_recursively(current_dict, *others)
35+
current_dict.update(updated_dict)
36+
37+
3238
# Occasionally convenient in order to write dict.x instead of more laborious
3339
# (and less in keeping with all other attr accesses) dict["x"]
3440

Binary file not shown.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from manim import *
4+
from ..utils.testing_utils import get_scenes_to_test
5+
from ..utils.GraphicalUnitTester import GraphicalUnitTester
6+
7+
8+
class NumberPlaneTest(Scene):
9+
def construct(self):
10+
plane = NumberPlane(
11+
axis_config={"include_numbers": True, "x_min": -8},
12+
x_min=-4,
13+
x_max=6,
14+
x_axis_config={"unit_size": 1.2},
15+
y_min=-2,
16+
y_axis_config={"x_max": 5, "width": 6, "label_direction": UL},
17+
center_point=2 * DL,
18+
)
19+
self.play(Animation(plane))
20+
21+
22+
MODULE_NAME = "coordinate_systems"
23+
24+
25+
@pytest.mark.parametrize("scene_to_test", get_scenes_to_test(__name__), indirect=False)
26+
def test_scene(scene_to_test, tmpdir, show_diff):
27+
GraphicalUnitTester(scene_to_test[1], MODULE_NAME, tmpdir).test(show_diff=show_diff)

0 commit comments

Comments
 (0)