@@ -107,9 +107,20 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
                 if (use_gpu):
                     cls.use_gpu = True
                     # cls.tensor_backend_pointer['active_backend'].set_default_device("cuda")
+                    # Check if CUDA is available
+                    if not pytorch_copy.cuda.is_available():
+                        raise RuntimeError("GPU requested but CUDA is not available in PyTorch")
+                    if False:
+                        # Set default device to CUDA
+                        cls.device = pytorch_copy.device("cuda")
+                        pytorch_copy.set_default_device("cuda")
+                        print(f"GPU enabled. Using device: {cls.device}")
+                        print(f"GPU device count: {pytorch_copy.cuda.device_count()}")
+                        print(f"Current GPU device: {pytorch_copy.cuda.current_device()}")
                 else:
                     cls.use_gpu = False
-
+                    cls.device = pytorch_copy.device("cpu")
+                    pytorch_copy.set_default_device("cpu")
             case (_):
                 raise AttributeError(
                     f"Engine Backend: {engine_backend} cannot be used because the correspondent library"
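For reference, a minimal standalone sketch of the availability check and default-device switch this hunk relies on. The `select_device` helper below is hypothetical (not part of the commit); it only strings together the same `torch.cuda.is_available()` / `torch.set_default_device()` calls, the latter assuming PyTorch >= 2.0:

```python
import torch

def select_device(use_gpu: bool) -> torch.device:
    """Hypothetical helper: fail fast if a GPU is requested but CUDA is missing."""
    if use_gpu:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU requested but CUDA is not available in PyTorch")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    torch.set_default_device(device)  # newly created tensors default to this device (PyTorch >= 2.0)
    return device

print(torch.zeros(2, device=select_device(use_gpu=False)).device)  # cpu
```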
@@ -134,7 +145,7 @@ def describe_conf(cls):
 
     @classmethod
     def _wrap_pytorch_functions(cls):
-        from torch import sum, repeat_interleave
+        from torch import sum, repeat_interleave, isclose
         import torch
 
         def _sum(tensor, axis=None, dtype=None, keepdims=False):
@@ -155,6 +166,11 @@ def _array(array_like, dtype=None):
                 if dtype is None: return array_like
                 else: return array_like.type(dtype)
             else:
+                # Ensure numpy arrays are contiguous before converting to torch tensor
+                if isinstance(array_like, numpy.ndarray):
+                    if not array_like.flags.c_contiguous:
+                        array_like = numpy.ascontiguousarray(array_like)
+
                 return torch.tensor(array_like, dtype=dtype)
 
         def _concatenate(tensors, axis=0, dtype=None):
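A side note on the contiguity guard added to `_array`: `numpy.ascontiguousarray` copies only when the input is not already C-contiguous, so the check is cheap in the common case. A small sketch of the same pattern on a non-contiguous view:

```python
import numpy
import torch

a = numpy.arange(12).reshape(3, 4).T      # transposed view -> not C-contiguous
print(a.flags.c_contiguous)               # False

if not a.flags.c_contiguous:
    a = numpy.ascontiguousarray(a)        # copy into C order

t = torch.tensor(a, dtype=torch.float32)  # same call the wrapper falls through to
print(t.shape)                            # torch.Size([4, 3])
```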
@@ -167,6 +183,95 @@ def _concatenate(tensors, axis=0, dtype=None):
 
         def _transpose(tensor, axes=None):
             return tensor.transpose(axes[0], axes[1])
+
+
+        def _packbits(tensor, axis=None, bitorder="big"):
+            """
+            Pack boolean values into uint8 bytes along the specified axis.
+            For a (4, n) tensor with axis=0, the four rows are zero-padded to eight
+            and packed into a single uint8 row (matching numpy.packbits).
+            """
+            # Convert to uint8 if boolean
+            if tensor.dtype == torch.bool:
+                tensor = tensor.to(torch.uint8)
+
+            if axis == 0:
+                # Pack along axis 0 (rows)
+                n_rows, n_cols = tensor.shape
+                n_output_rows = (n_rows + 7) // 8  # Round up to nearest byte boundary
+
+                # Pad with zeros if we don't have a multiple of 8 rows
+                if n_rows % 8 != 0:
+                    padding_rows = 8 - (n_rows % 8)
+                    padding = torch.zeros(padding_rows, n_cols, dtype=torch.uint8, device=tensor.device)
+                    tensor = torch.cat([tensor, padding], dim=0)
+
+                # Reshape to group every 8 rows together: (n_output_rows, 8, n_cols)
+                tensor_reshaped = tensor.view(n_output_rows, 8, n_cols)
+
+                # Define bit positions (powers of 2)
+                if bitorder == "little":
+                    # Little endian: LSB first [1, 2, 4, 8, 16, 32, 64, 128]
+                    powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 8, 1)
+                else:
+                    # Big endian: MSB first [128, 64, 32, 16, 8, 4, 2, 1]
+                    powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 8, 1)
+
+                # Pack bits: multiply each bit by its power and sum along the 8-bit dimension
+                packed = (tensor_reshaped * powers).sum(dim=1)  # Shape: (n_output_rows, n_cols)
+
+                return packed
+
+            elif axis == 1:
+                # Pack along axis 1 (columns)
+                n_rows, n_cols = tensor.shape
+                n_output_cols = (n_cols + 7) // 8
+
+                # Pad with zeros if needed
+                if n_cols % 8 != 0:
+                    padding_cols = 8 - (n_cols % 8)
+                    padding = torch.zeros(n_rows, padding_cols, dtype=torch.uint8, device=tensor.device)
+                    tensor = torch.cat([tensor, padding], dim=1)
+
+                # Reshape: (n_rows, n_output_cols, 8)
+                tensor_reshaped = tensor.view(n_rows, n_output_cols, 8)
+
+                # Define bit positions
+                if bitorder == "little":
+                    powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 1, 8)
+                else:
+                    powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1],
+                                          dtype=torch.uint8, device=tensor.device).view(1, 1, 8)
+
+                packed = (tensor_reshaped * powers).sum(dim=2)  # Shape: (n_rows, n_output_cols)
+                return packed
+
+            else:
+                raise NotImplementedError(f"packbits not implemented for axis={axis}")
+
+
+        def _to_numpy(tensor):
+            """Convert tensor to numpy array, handling GPU tensors properly"""
+            if hasattr(tensor, 'device') and tensor.device.type == 'cuda':
+                # Move to CPU first, then detach and convert to numpy
+                return tensor.cpu().detach().numpy()
+            elif hasattr(tensor, 'detach'):
+                # CPU tensor, just detach and convert
+                return tensor.detach().numpy()
+            else:
+                # Not a torch tensor, return as-is
+                return tensor
+
+        def _fill_diagonal(tensor, value):
+            """Fill the diagonal of a 2D tensor with the given value"""
+            if tensor.dim() != 2:
+                raise ValueError("fill_diagonal only supports 2D tensors")
+            diagonal_indices = torch.arange(min(tensor.size(0), tensor.size(1)))
+            tensor[diagonal_indices, diagonal_indices] = value
+            return tensor
 
         cls.tfnp.sum = _sum
         cls.tfnp.repeat = _repeat
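One way to sanity-check the `_packbits` helper is against `numpy.packbits`, which it is meant to mirror. Since the helper is local to `_wrap_pytorch_functions`, the sketch below re-derives the axis=0 / big-endian branch standalone rather than calling into the library:

```python
import numpy as np
import torch

bits = np.random.rand(4, 6) > 0.5                     # (4, 6) boolean array
expected = np.packbits(bits, axis=0)                   # numpy pads rows up to a multiple of 8

t = torch.tensor(bits).to(torch.uint8)
n_rows, n_cols = t.shape
pad = (-n_rows) % 8                                    # rows missing to reach a byte boundary
t = torch.cat([t, torch.zeros(pad, n_cols, dtype=torch.uint8)], dim=0)
powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8).view(1, 8, 1)
packed = (t.view(-1, 8, n_cols) * powers).sum(dim=1)   # big-endian: the first row is the MSB

print(np.array_equal(packed.numpy(), expected))        # True
```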
@@ -175,7 +280,7 @@ def _transpose(tensor, axes=None):
         cls.tfnp.flip = lambda tensor, axis: tensor.flip(axis)
         cls.tfnp.hstack = lambda tensors: torch.concat(tensors, dim=1)
         cls.tfnp.array = _array
-        cls.tfnp.to_numpy = lambda tensor: tensor.detach().numpy()
+        cls.tfnp.to_numpy = _to_numpy
         cls.tfnp.min = lambda tensor, axis: tensor.min(axis=axis)[0]
         cls.tfnp.max = lambda tensor, axis: tensor.max(axis=axis)[0]
         cls.tfnp.rint = lambda tensor: tensor.round().type(torch.int32)
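The switch from the old lambda to `_to_numpy` matters because `.numpy()` raises both on CUDA tensors and on tensors that still require grad. A condensed standalone variant of the same idea (not the exact helper, which branches on the device type):

```python
import torch

def to_numpy(tensor):
    # Detach from the autograd graph and move to host memory before converting.
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor

x = torch.ones(3, requires_grad=True)
print(to_numpy(x))        # [1. 1. 1.]
print(to_numpy([1, 2]))   # non-tensors pass through unchanged
```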
@@ -185,6 +290,17 @@ def _transpose(tensor, axes=None):
         cls.tfnp.transpose = _transpose
         cls.tfnp.geomspace = lambda start, stop, step: torch.logspace(start, stop, step, base=10)
         cls.tfnp.abs = lambda tensor, dtype=None: tensor.abs().type(dtype) if dtype is not None else tensor.abs()
+        cls.tfnp.tile = lambda tensor, repeats: tensor.repeat(repeats)
+        cls.tfnp.ravel = lambda tensor: tensor.flatten()
+        cls.tfnp.packbits = _packbits
+        cls.tfnp.fill_diagonal = _fill_diagonal
+        cls.tfnp.isclose = lambda a, b, rtol=1e-05, atol=1e-08, equal_nan=False: isclose(
+            a,
+            torch.tensor(b, dtype=a.dtype, device=a.device),
+            rtol=rtol,
+            atol=atol,
+            equal_nan=equal_nan
+        )
 
     @classmethod
     def _wrap_pykeops_functions(cls):
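The new `cls.tfnp.isclose` wrapper assumes `b` is a scalar (or anything `torch.tensor` accepts) and promotes it to `a`'s dtype and device before delegating to `torch.isclose`. Roughly equivalent, outside the class:

```python
import torch

a = torch.tensor([1.0, 1.0 + 1e-9, 2.0])
b = 1.0

result = torch.isclose(a, torch.tensor(b, dtype=a.dtype, device=a.device),
                       rtol=1e-05, atol=1e-08, equal_nan=False)
print(result)  # tensor([ True,  True, False])
```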