
Commit 32a465b

tklausen authored and facebook-github-bot committed
Fix GPU-CPU device mismatch error in util filter_dilated_rows (#633)
Summary:

## Types of changes
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue
The function `filter_dilated_rows` in `tensor_utils.py` converts a tensor to an ndarray, modifies the ndarray, and converts the modified ndarray back to a tensor.

**Bug:** If the original tensor is not on the CPU, the conversion to ndarray fails because `tensor.cpu()` is not called:

```
File "opacus/utils/tensor_utils.py", line 328, in filter_dilated_rows
    tensor_np = tensor.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
```

**Fix:** This PR modifies the tensor directly, without ever converting it to an ndarray. This fixes the bug and is more efficient than the original implementation.

## How Has This Been Tested (if it applies)
Manually tested with the example provided in the function's docstring. Also, `filter_dilated_rows` is called whenever the dilation of a 3d convolution is not 1, so this function is implicitly exercised by `tests/grad_samples/conv3d_test.py`.

## Checklist
- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #633

Reviewed By: karthikprasad

Differential Revision: D54199129

fbshipit-source-id: 56026a8f298517e27b67cf77de06f94ab63d0a9c
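To illustrate what the row filtering computes (the selection semantics that the patch preserves while keeping the tensor on its original device), here is a minimal NumPy sketch. The function name, shapes, and argument names below are illustrative, not Opacus's actual API:

```python
import numpy as np

def filter_dilated_rows_sketch(arr, kernel_size, dilation):
    """Keep only every `dilation`-th index along each trailing kernel axis.

    Illustrative analogue of the selection performed by Opacus's
    filter_dilated_rows; the real fix does this with torch.index_select
    so the tensor never leaves its device.
    """
    # Size of each kernel axis after dilation: d * (k - 1) + 1
    dilated = [d * (k - 1) + 1 for k, d in zip(kernel_size, dilation)]
    kernel_rank = len(kernel_size)
    axis_offset = arr.ndim - kernel_rank
    for dim in range(kernel_rank):
        idx = np.arange(0, dilated[dim], dilation[dim])
        arr = np.take(arr, idx, axis=axis_offset + dim)
    return arr

# A (batch=2, 5, 5) array; with kernel (3, 3) and dilation (2, 2),
# indices 0, 2, 4 are kept along each of the two trailing axes.
x = np.arange(2 * 5 * 5).reshape(2, 5, 5)
y = filter_dilated_rows_sketch(x, kernel_size=(3, 3), dilation=(2, 2))
# y has shape (2, 3, 3)
```

In the actual fix, the same per-axis index lists are built with `torch.arange(..., device=tensor.device)` and applied with `torch.index_select`, which is why no CPU round-trip is needed.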
1 parent ac639af commit 32a465b

File tree

1 file changed: +6 −5 lines


opacus/utils/tensor_utils.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -322,14 +322,15 @@ def filter_dilated_rows(
     kernel_rank = len(kernel_size)

     indices_to_keep = [
-        list(range(0, dilated_kernel_size[i], dilation[i])) for i in range(kernel_rank)
+        torch.arange(0, dilated_kernel_size[i], dilation[i], device=tensor.device)
+        for i in range(kernel_rank)
     ]

-    tensor_np = tensor.numpy()
-
     axis_offset = len(tensor.shape) - kernel_rank

     for dim in range(kernel_rank):
-        tensor_np = np.take(tensor_np, indices_to_keep[dim], axis=axis_offset + dim)
+        tensor = torch.index_select(
+            tensor, dim=axis_offset + dim, index=indices_to_keep[dim]
+        )

-    return torch.Tensor(tensor_np)
+    return tensor
```

0 commit comments
