Skip to content

Commit a1d40e1

Browse files
authored
Vectorized plot constant (#1524)
Allow constant expression in vectorized Plot3D. This change affects only the vectorized version; added a test for constant expression to verify that it works already for Plot3D classic.
1 parent c88706e commit a1d40e1

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

mathics/eval/drawing/plot3d_vectorized.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,21 @@ def eval_Plot3D(
2323
):
2424
graphics = GraphicsGenerator(dim=3)
2525

26-
for function in plot_options.functions:
27-
# pull out plot options
28-
_, xmin, xmax = plot_options.ranges[0]
29-
_, ymin, ymax = plot_options.ranges[1]
30-
nx, ny = plot_options.plotpoints
31-
names = [strip_context(str(range[0])) for range in plot_options.ranges]
26+
# pull out plot options
27+
_, xmin, xmax = plot_options.ranges[0]
28+
_, ymin, ymax = plot_options.ranges[1]
29+
nx, ny = plot_options.plotpoints
30+
names = [strip_context(str(range[0])) for range in plot_options.ranges]
31+
32+
# compute (nx, ny) grids of xs and ys for corresponding vertexes
33+
xs = np.linspace(xmin, xmax, nx)
34+
ys = np.linspace(ymin, ymax, ny)
35+
xs, ys = np.meshgrid(xs, ys)
3236

37+
for function in plot_options.functions:
3338
with Timer("compile"):
3439
function = plot_compile(evaluation, function, names)
3540

36-
# compute (nx, ny) grids of xs and ys for corresponding vertexes
37-
xs = np.linspace(xmin, xmax, nx)
38-
ys = np.linspace(ymin, ymax, ny)
39-
xs, ys = np.meshgrid(xs, ys)
40-
4141
# compute zs from xs and ys using compiled function
4242
with Timer("compute zs"):
4343
zs = function(**{str(names[0]): xs, str(names[1]): ys})
@@ -49,6 +49,10 @@ def eval_Plot3D(
4949
# assert np.all(np.isreal(zs)), "array contains complex values"
5050
zs = np.real(zs)
5151

52+
# if it's a constant, make it a full array
53+
if isinstance(zs, (float, int, complex)):
54+
zs = np.full(xs.shape, zs)
55+
5256
with Timer("stack"):
5357
# (nx*ny, 3) array of points, to be indexed by quads
5458
xyzs = np.stack([xs, ys, zs]).transpose(1, 2, 0).reshape(-1, 3)

test/builtin/drawing/test_plot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test__listplot():
6666
"-Graphics3D-",
6767
None,
6868
),
69+
("Plot3D[1, {x,-1,1}, {y,-1,1}]", None, "-Graphics3D-", None),
6970
(
7071
"Plot3D[]",
7172
("Plot3D called with 0 arguments; 3 arguments are expected.",),

0 commit comments

Comments
 (0)