
Commit d6ba230

[ENH] Add PyTorch backend support for dual contouring and octree operations (#21)
# [ENH] Adding PyTorch GPU Support

This PR adds support for PyTorch GPU acceleration by:

- Implementing a PyTorch backend for tensor operations throughout the codebase
- Adding GPU device detection and configuration in the backend tensor module
- Creating PyTorch-compatible implementations of key functions like `packbits`, `to_numpy`, and matrix operations
- Adapting octree generation and dual contouring algorithms to work with PyTorch tensors
- Ensuring proper tensor conversion between CPU and GPU when needed
- Implementing tensor-specific operations for both backends to maintain compatibility
- Adding proper memory management for contiguous arrays in PyTorch

The implementation ensures that operations can run on either CPU or GPU when using the PyTorch backend, with appropriate error handling when CUDA is not available but GPU is requested.
2 parents 7935462 + c6925dd commit d6ba230
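For orientation, here is a minimal usage sketch of what the change enables. The backend-switching entry point and the `use_gpu` keyword are inferred from the `_change_backend` signature in the diff below; treat the exact call as an assumption rather than the documented public API.

```python
from gempy_engine.config import AvailableBackends
from gempy_engine.core.backend_tensor import BackendTensor

# Sketch (assumed entry point): activate the PyTorch backend, optionally on GPU.
# With use_gpu=True this now raises RuntimeError if CUDA is not available.
BackendTensor._change_backend(AvailableBackends.PYTORCH, use_gpu=False)

# From here on, BackendTensor.t / BackendTensor.tfnp dispatch to torch instead of numpy.
xyz = BackendTensor.tfnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
xyz_np = BackendTensor.t.to_numpy(xyz)  # moves CUDA tensors to CPU before converting
```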

11 files changed: +424 −134 lines changed


gempy_engine/API/dual_contouring/_dual_contouring.py

Lines changed: 5 additions & 3 deletions
@@ -1,3 +1,4 @@
+import numpy
 import warnings
 from typing import List

@@ -90,18 +91,19 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
             tree_depth = dc_data_per_surface.tree_depth,
             voxel_normals = voxel_normal
         )
-        indices = np.vstack(indices)
+        indices = BackendTensor.t.concatenate(indices, axis=0)

         # @on
         vertices_numpy = BackendTensor.t.to_numpy(vertices)
+        indices_numpy = BackendTensor.t.to_numpy(indices)

         if TRIMESH_LAST_PASS := True:
-            vertices_numpy, indices = _last_pass(vertices_numpy, indices)
+            vertices_numpy, indices_numpy = _last_pass(vertices_numpy, indices_numpy)

         stack_meshes.append(
             DualContouringMesh(
                 vertices_numpy,
-                indices,
+                indices_numpy,
                 dc_data_per_stack
             )
         )
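In isolation, the conversion pattern this hunk standardizes on looks roughly like the sketch below (plain torch/numpy calls standing in for the `BackendTensor.t` wrappers; the surrounding gempy_engine objects are omitted and the index values are made up):

```python
import numpy as np
import torch

# Per-surface triangle index arrays, as produced inside the loop above (made-up values).
indices_per_surface = [torch.tensor([[0, 1, 2]]), torch.tensor([[3, 4, 5], [5, 6, 7]])]

# BackendTensor.t.concatenate(..., axis=0) plays the role np.vstack played before.
indices = torch.cat(indices_per_surface, dim=0)

# BackendTensor.t.to_numpy detaches (and moves CUDA tensors to CPU) before .numpy(),
# so the trimesh-based _last_pass only ever sees plain numpy arrays.
indices_numpy = indices.detach().cpu().numpy()

assert np.array_equal(indices_numpy, np.vstack([t.numpy() for t in indices_per_surface]))
```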

gempy_engine/API/interp_single/_octree_generation.py

Lines changed: 39 additions & 1 deletion
@@ -58,7 +58,7 @@ def interpolate_on_octree(interpolation_input: InterpolationInput, options: Inte
     return next_octree_level


-def _generate_corners(regular_grid: RegularGrid, level=1) -> np.ndarray:
+def _generate_corners_DEP(regular_grid: RegularGrid, level=1) -> np.ndarray:
     if regular_grid is None: raise ValueError("Regular grid is None")

     xyz_coord = regular_grid.values
@@ -79,3 +79,41 @@ def stack_left_right(a_edg, d_a):

     new_xyz = np.stack((x, y, z)).T
     return np.ascontiguousarray(new_xyz)
+
+def _generate_corners(regular_grid: RegularGrid, level=1):
+    if regular_grid is None:
+        raise ValueError("Regular grid is None")
+
+    # Convert to backend tensors
+    # xyz_coord = BackendTensor.tfnp.array(regular_grid.values)
+    # dxdydz = BackendTensor.tfnp.array(regular_grid.dxdydz)
+
+    xyz_coord = regular_grid.values
+    dxdydz = regular_grid.dxdydz
+
+    x_coord, y_coord, z_coord = xyz_coord[:, 0], xyz_coord[:, 1], xyz_coord[:, 2]
+    dx, dy, dz = dxdydz[0], dxdydz[1], dxdydz[2]
+
+    def stack_left_right(a_edg, d_a):
+        left = a_edg - d_a / level / 2
+        right = a_edg + d_a / level / 2
+        return BackendTensor.tfnp.stack([left, right], axis=1)
+
+    x_ = BackendTensor.tfnp.repeat(stack_left_right(x_coord, dx), 4, axis=1)
+    x = BackendTensor.tfnp.ravel(x_)
+
+    y_temp = BackendTensor.tfnp.repeat(stack_left_right(y_coord, dy), 2, axis=1)
+    y_ = BackendTensor.tfnp.tile(y_temp, (1, 2))
+    y = BackendTensor.tfnp.ravel(y_)
+
+    z_ = BackendTensor.tfnp.tile(stack_left_right(z_coord, dz), (1, 4))
+    z = BackendTensor.tfnp.ravel(z_)
+
+    new_xyz = BackendTensor.tfnp.stack([x, y, z], axis=1)
+
+    # For PyTorch, ensure contiguous memory (equivalent to np.ascontiguousarray)
+    if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
+        if hasattr(new_xyz, 'contiguous'):
+            new_xyz = new_xyz.contiguous()
+
+    return new_xyz
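A small numpy illustration of the corner layout the new `_generate_corners` produces (numpy is used here only for readability; the function itself routes everything through `BackendTensor.tfnp`, so the same pattern runs on torch tensors):

```python
import numpy as np

# One voxel center at (0, 0, 0), size 2 x 2 x 2, level=1.
center = np.array([[0.0, 0.0, 0.0]])
dx = dy = dz = 2.0

def stack_left_right(a_edg, d_a, level=1):
    return np.stack([a_edg - d_a / level / 2, a_edg + d_a / level / 2], axis=1)

x = np.repeat(stack_left_right(center[:, 0], dx), 4, axis=1).ravel()
y = np.tile(np.repeat(stack_left_right(center[:, 1], dy), 2, axis=1), (1, 2)).ravel()
z = np.tile(stack_left_right(center[:, 2], dz), (1, 4)).ravel()

corners = np.stack([x, y, z], axis=1)
print(corners)
# 8 corners per voxel, z varying fastest, then y, then x:
# [[-1. -1. -1.], [-1. -1.  1.], [-1.  1. -1.], [-1.  1.  1.],
#  [ 1. -1. -1.], [ 1. -1.  1.], [ 1.  1. -1.], [ 1.  1.  1.]]
```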

gempy_engine/core/backend_tensor.py

Lines changed: 119 additions & 3 deletions
@@ -107,9 +107,20 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
                 if (use_gpu):
                     cls.use_gpu = True
                     # cls.tensor_backend_pointer['active_backend'].set_default_device("cuda")
+                    # Check if CUDA is available
+                    if not pytorch_copy.cuda.is_available():
+                        raise RuntimeError("GPU requested but CUDA is not available in PyTorch")
+                    if False:
+                        # Set default device to CUDA
+                        cls.device = pytorch_copy.device("cuda")
+                        pytorch_copy.set_default_device("cuda")
+                        print(f"GPU enabled. Using device: {cls.device}")
+                        print(f"GPU device count: {pytorch_copy.cuda.device_count()}")
+                        print(f"Current GPU device: {pytorch_copy.cuda.current_device()}")
                 else:
                     cls.use_gpu = False
-
+                    cls.device = pytorch_copy.device("cpu")
+                    pytorch_copy.set_default_device("cpu")
             case (_):
                 raise AttributeError(
                     f"Engine Backend: {engine_backend} cannot be used because the correspondent library"
@@ -134,7 +145,7 @@ def describe_conf(cls):

     @classmethod
     def _wrap_pytorch_functions(cls):
-        from torch import sum, repeat_interleave
+        from torch import sum, repeat_interleave, isclose
         import torch

         def _sum(tensor, axis=None, dtype=None, keepdims=False):
@@ -155,6 +166,11 @@ def _array(array_like, dtype=None):
                 if dtype is None: return array_like
                 else: return array_like.type(dtype)
             else:
+                # Ensure numpy arrays are contiguous before converting to torch tensor
+                if isinstance(array_like, numpy.ndarray):
+                    if not array_like.flags.c_contiguous:
+                        array_like = numpy.ascontiguousarray(array_like)
+
                 return torch.tensor(array_like, dtype=dtype)

         def _concatenate(tensors, axis=0, dtype=None):
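The contiguity guard added to `_array` above targets cases like the following, where a numpy view is not C-contiguous before it is handed to `torch.tensor` (a minimal illustration):

```python
import numpy as np
import torch

a = np.arange(12, dtype=np.float64).reshape(3, 4).T  # transposed view: not C-contiguous
print(a.flags.c_contiguous)                           # False

a = np.ascontiguousarray(a)                           # what the wrapper now does
t = torch.tensor(a, dtype=torch.float64)              # tensor built from a contiguous buffer
print(t.is_contiguous())                              # True
```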
@@ -167,6 +183,95 @@ def _concatenate(tensors, axis=0, dtype=None):

         def _transpose(tensor, axes=None):
             return tensor.transpose(axes[0], axes[1])
+
+
+        def _packbits(tensor, axis=None, bitorder="big"):
+            """
+            Pack boolean values into uint8 bytes along the specified axis.
+            For a (4, n) tensor with axis=0, this packs every 4 bits into nibbles,
+            then pads to create full bytes.
+            """
+            # Convert to uint8 if boolean
+            if tensor.dtype == torch.bool:
+                tensor = tensor.to(torch.uint8)
+
+            if axis == 0:
+                # Pack along axis 0 (rows)
+                n_rows, n_cols = tensor.shape
+                n_output_rows = (n_rows + 7) // 8  # Round up to nearest byte boundary
+
+                # Pad with zeros if we don't have multiples of 8 rows
+                if n_rows % 8 != 0:
+                    padding_rows = 8 - (n_rows % 8)
+                    padding = torch.zeros(padding_rows, n_cols, dtype=torch.uint8, device=tensor.device)
+                    tensor = torch.cat([tensor, padding], dim=0)
+
+                # Reshape to group every 8 rows together: (n_output_rows, 8, n_cols)
+                tensor_reshaped = tensor.view(n_output_rows, 8, n_cols)
+
+                # Define bit positions (powers of 2)
+                if bitorder == "little":
+                    # Little endian: LSB first [1, 2, 4, 8, 16, 32, 64, 128]
+                    powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 8, 1)
+                else:
+                    # Big endian: MSB first [128, 64, 32, 16, 8, 4, 2, 1]
+                    powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 8, 1)
+
+                # Pack bits: multiply each bit by its power and sum along the 8-bit dimension
+                packed = (tensor_reshaped * powers).sum(dim=1)  # Shape: (n_output_rows, n_cols)
+
+                return packed
+
+            elif axis == 1:
+                # Pack along axis 1 (columns)
+                n_rows, n_cols = tensor.shape
+                n_output_cols = (n_cols + 7) // 8
+
+                # Pad with zeros if needed
+                if n_cols % 8 != 0:
+                    padding_cols = 8 - (n_cols % 8)
+                    padding = torch.zeros(n_rows, padding_cols, dtype=torch.uint8, device=tensor.device)
+                    tensor = torch.cat([tensor, padding], dim=1)
+
+                # Reshape: (n_rows, n_output_cols, 8)
+                tensor_reshaped = tensor.view(n_rows, n_output_cols, 8)
+
+                # Define bit positions
+                if bitorder == "little":
+                    powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 1, 8)
+                else:
+                    powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 1, 8)
+
+                packed = (tensor_reshaped * powers).sum(dim=2)  # Shape: (n_rows, n_output_cols)
+                return packed
+
+            else:
+                raise NotImplementedError(f"packbits not implemented for axis={axis}")
+
+
+        def _to_numpy(tensor):
+            """Convert tensor to numpy array, handling GPU tensors properly"""
+            if hasattr(tensor, 'device') and tensor.device.type == 'cuda':
+                # Move to CPU first, then detach and convert to numpy
+                return tensor.cpu().detach().numpy()
+            elif hasattr(tensor, 'detach'):
+                # CPU tensor, just detach and convert
+                return tensor.detach().numpy()
+            else:
+                # Not a torch tensor, return as-is
+                return tensor
+
+        def _fill_diagonal(tensor, value):
+            """Fill the diagonal of a 2D tensor with the given value"""
+            if tensor.dim() != 2:
+                raise ValueError("fill_diagonal only supports 2D tensors")
+            diagonal_indices = torch.arange(min(tensor.size(0), tensor.size(1)))
+            tensor[diagonal_indices, diagonal_indices] = value
+            return tensor

         cls.tfnp.sum = _sum
         cls.tfnp.repeat = _repeat
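A quick self-check of the `_packbits` implementation against `np.packbits`, which it is meant to mirror (a sketch; it assumes the PyTorch backend has been activated so that `BackendTensor.tfnp.packbits` points at the wrapper above):

```python
import numpy as np
import torch

from gempy_engine.core.backend_tensor import BackendTensor

bits_np = np.random.rand(12, 5) > 0.5   # 12 rows of bits, to be packed along axis 0
bits_t = torch.from_numpy(bits_np)

packed_np = np.packbits(bits_np, axis=0)                                # shape (2, 5), zero-padded to 16 rows
packed_t = BackendTensor.tfnp.packbits(bits_t, axis=0, bitorder="big")  # same packing on torch tensors

assert np.array_equal(packed_np, BackendTensor.t.to_numpy(packed_t))
```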
@@ -175,7 +280,7 @@ def _transpose(tensor, axes=None):
         cls.tfnp.flip = lambda tensor, axis: tensor.flip(axis)
         cls.tfnp.hstack = lambda tensors: torch.concat(tensors, dim=1)
         cls.tfnp.array = _array
-        cls.tfnp.to_numpy = lambda tensor: tensor.detach().numpy()
+        cls.tfnp.to_numpy = _to_numpy
         cls.tfnp.min = lambda tensor, axis: tensor.min(axis=axis)[0]
         cls.tfnp.max = lambda tensor, axis: tensor.max(axis=axis)[0]
         cls.tfnp.rint = lambda tensor: tensor.round().type(torch.int32)
@@ -185,6 +290,17 @@ def _transpose(tensor, axes=None):
         cls.tfnp.transpose = _transpose
         cls.tfnp.geomspace = lambda start, stop, step: torch.logspace(start, stop, step, base=10)
         cls.tfnp.abs = lambda tensor, dtype = None: tensor.abs().type(dtype) if dtype is not None else tensor.abs()
+        cls.tfnp.tile = lambda tensor, repeats: tensor.repeat(repeats)
+        cls.tfnp.ravel = lambda tensor: tensor.flatten()
+        cls.tfnp.packbits = _packbits
+        cls.tfnp.fill_diagonal = _fill_diagonal
+        cls.tfnp.isclose = lambda a, b, rtol=1e-05, atol=1e-08, equal_nan=False: isclose(
+            a,
+            torch.tensor(b, dtype=a.dtype, device=a.device),
+            rtol=rtol,
+            atol=atol,
+            equal_nan=equal_nan
+        )

     @classmethod
     def _wrap_pykeops_functions(cls):
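The `tfnp.isclose` wrapper exists so numpy-style `isclose(tensor, scalar)` calls keep working under torch, which requires the scalar operand to be promoted to a tensor with matching dtype and device first. A minimal illustration of what the lambda above evaluates to:

```python
import torch

a = torch.tensor([1.0, 2.0, 2.0000001])
b = 2.0  # plain Python scalar, as numpy's isclose would accept

result = torch.isclose(a, torch.tensor(b, dtype=a.dtype, device=a.device), rtol=1e-05, atol=1e-08)
print(result)  # tensor([False,  True,  True])
```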

gempy_engine/core/data/interp_output.py

Lines changed: 4 additions & 3 deletions
@@ -127,8 +127,8 @@ def litho_faults_ids(self):
         faults_ids = BackendTensor.t.rint(self.faults_block)

         # Get the number of unique lithology IDs
-        multiplier = len(np.unique(litho_ids))
-
+        multiplier = len(BackendTensor.t.unique(litho_ids))
+
         # Generate the unique IDs
         unique_ids = litho_ids + faults_ids * multiplier
         return unique_ids
@@ -154,6 +154,7 @@ def get_block_from_value_type(self, value_type: ValueType, slice_: slice):

         match (BackendTensor.engine_backend):
             case AvailableBackends.PYTORCH:
-                block = block.detach().numpy()
+                block = BackendTensor.t.to_numpy(block)
+                # block = block.to_numpy()

         return block[slice_]

gempy_engine/core/data/regular_grid.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@

 import numpy as np

+from ..backend_tensor import BackendTensor
 from ..utils import _check_and_convert_list_to_array, cast_type_inplace
 from .kernel_classes.server.input_parser import GridSchema

@@ -27,7 +28,7 @@ def __post_init__(self):

         self._create_regular_grid_3d()

-        self.original_values = self.values.copy()
+        self.original_values = BackendTensor.t.copy(self.values)

     @property
     def dx(self):
@@ -196,5 +197,6 @@ def _create_regular_grid_3d(self):
         g = np.meshgrid(*coords, indexing="ij")
         values = np.vstack(tuple(map(np.ravel, g))).T.astype("float64")
         values = np.ascontiguousarray(values)
+        values = BackendTensor.tfnp.array(values, dtype=BackendTensor.dtype_obj)

         self.values = values

gempy_engine/core/utils.py

Lines changed: 5 additions & 1 deletion
@@ -12,7 +12,11 @@ def cast_type_inplace(data_instance: Any, requires_grad:bool = False):
             case (gempy_engine.config.AvailableBackends.numpy):
                 data_instance.__dict__[key] = val.astype(BackendTensor.dtype)
             case (gempy_engine.config.AvailableBackends.PYTORCH):
-                tensor = BackendTensor.t.from_numpy(val.astype(BackendTensor.dtype))
+                # tensor = BackendTensor.t.from_numpy(val.astype(BackendTensor.dtype))
+                # if (BackendTensor.use_gpu):
+                #     tensor = tensor.cuda()
+
+                tensor = BackendTensor.tfnp.array(val, dtype=BackendTensor.dtype_obj)
                 tensor.requires_grad = requires_grad
                 data_instance.__dict__[key] = tensor

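Illustrative use of the updated cast (the `SurfacePointsLike` dataclass below is hypothetical, invented only for the example; with the PyTorch backend active, numpy fields become torch tensors and can track gradients):

```python
from dataclasses import dataclass

import numpy as np

from gempy_engine.core.utils import cast_type_inplace

@dataclass
class SurfacePointsLike:          # hypothetical container, for illustration only
    sp_coords: np.ndarray

data = SurfacePointsLike(sp_coords=np.zeros((10, 3)))
cast_type_inplace(data, requires_grad=True)

# Under AvailableBackends.PYTORCH, sp_coords is now a torch.Tensor with requires_grad=True;
# under the numpy backend it is simply cast to BackendTensor.dtype.
print(type(data.sp_coords))
```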