diff --git a/manim/mobject/opengl/opengl_surface.py b/manim/mobject/opengl/opengl_surface.py index dec54c98ee..a020d1bcff 100644 --- a/manim/mobject/opengl/opengl_surface.py +++ b/manim/mobject/opengl/opengl_surface.py @@ -97,6 +97,9 @@ def __init__( # For du and dv steps. Much smaller and numerical error # can crop up in the shaders. self.epsilon = epsilon + self._colorscale_points = None + self._colorscale_min = None + self._colorscale_max = None self.triangle_indices = None super().__init__( @@ -297,6 +300,13 @@ def _get_color_by_value(self, s_points): List A list of colors matching the vertex inputs. """ + if self._colorscale_points is None: + self._colorscale_points = { + i: self.axes.point_to_coords(point)[self.colorscale_axis] + for i, point in enumerate(s_points) + } + self._colorscale_min = self.axes.z_range[0] + self._colorscale_max = self.axes.z_range[1] if type(self.colorscale[0]) in (list, tuple): new_colors, pivots = [ [i for i, j in self.colorscale], @@ -305,8 +315,8 @@ def _get_color_by_value(self, s_points): else: new_colors = self.colorscale - pivot_min = self.axes.z_range[0] - pivot_max = self.axes.z_range[1] + pivot_min = self._colorscale_min + pivot_max = self._colorscale_max pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1) pivots = np.arange( start=pivot_min, @@ -315,8 +325,8 @@ def _get_color_by_value(self, s_points): ) return_colors = [] - for point in s_points: - axis_value = self.axes.point_to_coords(point)[self.colorscale_axis] + for i, point in enumerate(s_points): + axis_value = self._colorscale_points[i] if axis_value <= pivots[0]: return_colors.append(color_to_rgba(new_colors[0], self.opacity)) elif axis_value >= pivots[-1]: