Skip to content

Commit 76c0986

Browse files
committed
1 parent 3831d85 commit 76c0986

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

maskdino/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ at::Tensor ms_deform_attn_cuda_forward(
3636
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
3737
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
3838

39-
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
4039
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
4140
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
4241
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
@@ -66,7 +65,7 @@ at::Tensor ms_deform_attn_cuda_forward(
6665
for (int n = 0; n < batch/im2col_step_; ++n)
6766
{
6867
auto columns = output_n.select(0, n);
69-
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
68+
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
7069
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
7170
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
7271
spatial_shapes.data<int64_t>(),
@@ -102,7 +101,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
102101
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103102
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104103

105-
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106104
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107105
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108106
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
@@ -136,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
136134
for (int n = 0; n < batch/im2col_step_; ++n)
137135
{
138136
auto grad_output_g = grad_output_n.select(0, n);
139-
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
137+
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
140138
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141139
grad_output_g.data<scalar_t>(),
142140
value.data<scalar_t>() + n * im2col_step_ * per_value_size,

0 commit comments

Comments
 (0)