diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index ad478431c..af3c044b8 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -188,7 +188,10 @@ def transform( if HIP_ENVIRONMENT: # transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm # Use nvidia_transform instead - return nvidia_transform(A, to_order, from_order, out, transpose, state, ld) + prev_device = pre_call(A.device) + out, new_state = nvidia_transform(A, to_order, from_order, out, transpose, state, ld) + post_call(prev_device) + return out, new_state prev_device = pre_call(A.device) if state is None: