Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 157 additions & 80 deletions src/e3sm_quickview/plugins/eam_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,27 @@
vtkPoints,
)
from vtkmodules.vtkCommonDataModel import (
vtkPolyData,
vtkCellArray,
vtkDataSetAttributes,
vtkPlane,
vtkPolyData,
)
from vtkmodules.vtkCommonTransforms import vtkTransform
from vtkmodules.vtkFiltersCore import (
vtkAppendFilter,
vtkCellCenters,
vtkGenerateIds,
vtkPolyDataToUnstructuredGrid,
)
from vtkmodules.vtkFiltersGeneral import (
vtkTransformFilter,
vtkTableBasedClipDataSet,
vtkTransformFilter,
)
from vtkmodules.vtkFiltersPoints import (
vtkExtractSurface
)
from vtkmodules.vtkFiltersGeometry import (
vtkGeometryFilter
)

try:
Expand Down Expand Up @@ -255,6 +264,11 @@ def __init__(self):
self.__Dims = -1
self.project = 0
self.translate = False
self.cached_points = None

def __del__(self):
if self.cached_points:
self.cached_points.Unregister()

def SetTranslation(self, translate):
if self.translate != translate:
Expand All @@ -267,55 +281,58 @@ def SetProjection(self, project):
self.Modified()

def RequestData(self, request, inInfo, outInfo):
if self.project == 0:
return 1
inData = self.GetInputData(inInfo, 0, 0)
outData = self.GetOutputData(outInfo, 0)
if inData.IsA("vtkPolyData"):
afilter = vtkAppendFilter()
afilter.AddInputData(inData)
afilter.Update()
outData.DeepCopy(afilter.GetOutput())
outData.ShallowCopy(afilter.GetOutput())
else:
outData.DeepCopy(inData)

if self.project == 0:
return 1

inWrap = dsa.WrapDataObject(inData)
outWrap = dsa.WrapDataObject(outData)
inPoints = np.array(inWrap.Points)

flat = inPoints.flatten()
x = flat[0::3] - 180.0 if self.translate else flat[0::3]
y = flat[1::3]

try:
# Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
latlon = Proj(proj="latlong", datum="WGS84")
if self.project == 1:
proj = Proj(proj="robin")
elif self.project == 2:
proj = Proj(proj="moll")
else:
# Should not reach here, but return without transformation
outData.ShallowCopy(inData)
if self.cached_points and \
self.cached_points.GetMTime() >= inData.GetPoints().GetMTime():
outData.SetPoints(self.cached_points)
else:
# we modify the points, so copy them
out_points_vtk = vtkPoints()
out_points_vtk.DeepCopy(inData.GetPoints())
outData.SetPoints(out_points_vtk)
out_points_np = outData.points

flat = out_points_np.flatten()
x = flat[0::3] - 180.0 if self.translate else flat[0::3]
y = flat[1::3]

try:
# Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
latlon = Proj(proj="latlong", datum="WGS84")
if self.project == 1:
proj = Proj(proj="robin")
elif self.project == 2:
proj = Proj(proj="moll")
else:
# Should not reach here, but return without transformation
return 1

xformer = Transformer.from_proj(latlon, proj, always_xy=True)
res = xformer.transform(x, y)
except Exception as e:
print(f"Projection error: {e}")
# If projection fails, return without modifying coordinates
return 1

xformer = Transformer.from_proj(latlon, proj, always_xy=True)
res = xformer.transform(x, y)
except Exception as e:
print(f"Projection error: {e}")
# If projection fails, return without modifying coordinates
return 1
flat[0::3] = np.array(res[0])
flat[1::3] = np.array(res[1])

outPoints = flat.reshape(inPoints.shape)
_coords = numpy_support.numpy_to_vtk(
outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT
)
vtk_coords = vtkPoints()
vtk_coords.SetData(_coords)
outWrap.SetPoints(vtk_coords)

flat[0::3] = np.array(res[0])
flat[1::3] = np.array(res[1])

outPoints = flat.reshape(out_points_np.shape)
_coords = numpy_support.numpy_to_vtk(outPoints, deep=True)
outData.GetPoints().SetData(_coords)
if self.cached_points:
self.cached_points.Unregister(self)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than Unregister, should you just set self.cached_points = None?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to do this at all? If the python object is deleted, the c++ object should be unregistered as well, isn't it?

self.cached_points = out_points_vtk
self.cached_points.Register(self)
return 1


Expand Down Expand Up @@ -415,15 +432,15 @@ def RequestData(self, request, inInfo, outInfo):
@smdomain.datatype(dataTypes=["vtkPolyData"], composite_data_supported=False)
@smproperty.xml(
"""
<DoubleVectorProperty name="Longitude Range"
command="SetLongitudeRange"
<DoubleVectorProperty name="Trim Longitude"
command="SetTrimLongitude"
number_of_elements="2"
default_values="-180 180">
default_values="0 0">
</DoubleVectorProperty>
<DoubleVectorProperty name="Latitude Range"
command="SetLatitudeRange"
<DoubleVectorProperty name="Trim Latitude"
command="SetTrimLatitude"
number_of_elements="2"
default_values="-90 90">
default_values="0 0">
</DoubleVectorProperty>
"""
)
Expand All @@ -432,44 +449,99 @@ def __init__(self):
super().__init__(
nInputPorts=1, nOutputPorts=1, outputType="vtkUnstructuredGrid"
)
self.longrange = [-180.0, 180.0]
self.latrange = [-90.0, 90.0]

def SetLongitudeRange(self, min, max):
if self.longrange[0] != min or self.longrange[1] != max:
self.longrange = [min, max]
self.trim_lon = [0, 0]
self.trim_lat = [0, 0]
self.cached_cell_centers = None
self.cached_ghosts = None

def __del__(self):
if self.cached_cell_centers:
self.cached_cell_centers.Unregister(self)
if self.cached_ghosts:
self.cached_ghosts.Unregister(self)

def SetTrimLongitude(self, left, right):
if left < 0 or left > 180 or right < 0 or right > 180:
print_error(f"SetTrimLongitude called with parameters outside [0, 180]: {left=}, {right=}")
return
if self.trim_lon[0] != left or self.trim_lon[1] != right:
self.trim_lon = [left, right]
self.Modified()

def SetLatitudeRange(self, min, max):
if self.latrange[0] != min or self.latrange[1] != max:
self.latrange = [min, max]
def SetTrimLatitude(self, left, right):
if left < 0 or left > 90 or right < 0 or right > 90:
print_error(f"SetTrimLatitude called with parameters outside [0, 180]: {left=}, {right=}")
return
if self.trim_lat[0] != left or self.trim_lat[1] != right:
self.trim_lat = [left, right]
self.Modified()

def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
outData = self.GetOutputData(outInfo, 0)
if self.longrange == [-180.0, 180] and self.latrange == [-90, 90]:
if self.trim_lon == [0, 0] and self.trim_lat == [0, 0]:
outData.ShallowCopy(inData)
return 1

box = vtkPVBox()
box.SetReferenceBounds(
self.longrange[0],
self.longrange[1],
self.latrange[0],
self.latrange[1],
-1.0,
1.0,
)
box.SetUseReferenceBounds(True)
extract = vtkPVClipDataSet()
extract.SetClipFunction(box)
extract.InsideOutOn()
extract.ExactBoxClipOn()
extract.SetInputData(inData)
extract.Update()
outData.ShallowCopy(inData)
if self.cached_cell_centers and self.cached_cell_centers.GetMTime() >= max(
inData.GetPoints().GetMTime(), inData.GetCells().GetMTime()
):
cell_centers = self.cached_cell_centers
else:
# convert to polydata, as vtkCellCenters only works on polydata
# import pdb;pdb.set_trace()
to_poly = vtkGeometryFilter()
to_poly.SetInputData(inData)

# get cell centers
compute_centers = vtkCellCenters()
compute_centers.SetInputConnection(to_poly.GetOutputPort())
compute_centers.Update()
cell_centers = compute_centers.GetOutput().GetPoints().GetData()
if self.cached_cell_centers:
self.cached_cell_centers.Unregister(self)
self.cached_cell_centers = cell_centers
self.cached_cell_centers.Register(self)

# get the numpy array for cell centers
cc = numpy_support.vtk_to_numpy(cell_centers)

if self.cached_ghosts and self.cached_ghosts.GetMTime() >= max(
self.GetMTime(), inData.GetPoints().GetMTime(), cell_centers.GetMTime()
):
ghost = self.cached_ghosts
else:
# import pdb;pdb.set_trace()
# compute the new bounds by trimming the inData bounds
bounds = list(inData.GetBounds())
bounds[0] = bounds[0] + self.trim_lon[0]
bounds[1] = bounds[1] - self.trim_lon[1]
bounds[2] = bounds[2] + self.trim_lat[0]
bounds[3] = bounds[3] - self.trim_lat[1]

# add hidden cells based on bounds
outside_mask = (
(cc[:, 0] < bounds[0])
| (cc[:, 0] > bounds[1])
| (cc[:, 1] < bounds[2])
| (cc[:, 1] > bounds[3])
)

# Create ghost array (0 = visible, HIDDENCELL = invisible)
ghost_np = np.where(
outside_mask, vtkDataSetAttributes.HIDDENCELL, 0
).astype(np.uint8)

# Convert to VTK and add to output
ghost = numpy_support.numpy_to_vtk(ghost_np)
ghost.SetName(vtkDataSetAttributes.GhostArrayName())
if self.cached_ghosts:
self.cached_ghosts.Unregister(self)
self.cached_ghosts = ghost
self.cached_ghosts.Register(self)
outData.GetCellData().AddArray(ghost)

outData.ShallowCopy(extract.GetOutput())
return 1


Expand Down Expand Up @@ -513,6 +585,10 @@ def __init__(self):
self._center_meridian = 0
self._cached_output = None

def __del__(self):
if self._cached_output:
self._cached_output.Unregister(self)

def SetMeridian(self, meridian_):
"""
Specifies the central meridian (longitude in the middle of the map)
Expand All @@ -535,14 +611,12 @@ def GetMeridian(self):

def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
inPoints = inData.GetPoints()
inCellArray = inData.GetCells()

outData = self.GetOutputData(outInfo, 0)
if (
self._cached_output
and self._cached_output.GetMTime() > inPoints.GetMTime()
and self._cached_output.GetMTime() > inCellArray.GetMTime()
and self._cached_output.GetPoints().GetMTime() >= inData.GetPoints().GetMTime()
and self._cached_output.GetCells().GetMTime() >= inData.GetCells().GetMTime()
):
# only scalars have been added or removed
cached_cell_data = self._cached_output.GetCellData()
Expand Down Expand Up @@ -606,6 +680,9 @@ def RequestData(self, request, inInfo, outInfo):
append.AddInputData(transform.GetOutput())
append.Update()
outData.ShallowCopy(append.GetOutput())
if self._cached_output:
self._cached_output.Unregister(self)
self._cached_output = outData.NewInstance()
self._cached_output.ShallowCopy(outData)
self._cached_output.Register(self)
return 1
Loading