@@ -36,7 +36,6 @@ at::Tensor ms_deform_attn_cuda_forward(
3636 AT_ASSERTM (sampling_loc.is_contiguous (), " sampling_loc tensor has to be contiguous" );
3737 AT_ASSERTM (attn_weight.is_contiguous (), " attn_weight tensor has to be contiguous" );
3838
39- AT_ASSERTM (value.type ().is_cuda (), " value must be a CUDA tensor" );
4039 AT_ASSERTM (spatial_shapes.type ().is_cuda (), " spatial_shapes must be a CUDA tensor" );
4140 AT_ASSERTM (level_start_index.type ().is_cuda (), " level_start_index must be a CUDA tensor" );
4241 AT_ASSERTM (sampling_loc.type ().is_cuda (), " sampling_loc must be a CUDA tensor" );
@@ -66,7 +65,7 @@ at::Tensor ms_deform_attn_cuda_forward(
6665 for (int n = 0 ; n < batch/im2col_step_; ++n)
6766 {
6867 auto columns = output_n.select (0 , n);
69- AT_DISPATCH_FLOATING_TYPES (value.type (), " ms_deform_attn_forward_cuda" , ([&] {
68+ AT_DISPATCH_FLOATING_TYPES (value.scalar_type (), " ms_deform_attn_forward_cuda" , ([&] {
7069 ms_deformable_im2col_cuda (at::cuda::getCurrentCUDAStream (),
7170 value.data <scalar_t >() + n * im2col_step_ * per_value_size,
7271 spatial_shapes.data <int64_t >(),
@@ -102,7 +101,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
102101 AT_ASSERTM (attn_weight.is_contiguous (), " attn_weight tensor has to be contiguous" );
103102 AT_ASSERTM (grad_output.is_contiguous (), " grad_output tensor has to be contiguous" );
104103
105- AT_ASSERTM (value.type ().is_cuda (), " value must be a CUDA tensor" );
106104 AT_ASSERTM (spatial_shapes.type ().is_cuda (), " spatial_shapes must be a CUDA tensor" );
107105 AT_ASSERTM (level_start_index.type ().is_cuda (), " level_start_index must be a CUDA tensor" );
108106 AT_ASSERTM (sampling_loc.type ().is_cuda (), " sampling_loc must be a CUDA tensor" );
@@ -136,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
136134 for (int n = 0 ; n < batch/im2col_step_; ++n)
137135 {
138136 auto grad_output_g = grad_output_n.select (0 , n);
139- AT_DISPATCH_FLOATING_TYPES (value.type (), " ms_deform_attn_backward_cuda" , ([&] {
137+ AT_DISPATCH_FLOATING_TYPES (value.scalar_type (), " ms_deform_attn_backward_cuda" , ([&] {
140138 ms_deformable_col2im_cuda (at::cuda::getCurrentCUDAStream (),
141139 grad_output_g.data <scalar_t >(),
142140 value.data <scalar_t >() + n * im2col_step_ * per_value_size,
0 commit comments