1515#include < ATen/cuda/CUDAContext.h>
1616#include < cuda.h>
1717#include < cuda_runtime.h>
18+ #include < torch/extension.h>
19+ #include < torch/version.h>
20+
21+ // Pick the tensor dtype/device accessors by PyTorch version: 2.6 and newer use scalar_type()/device(), older releases use the legacy type() API.
22+ #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
23+ // PyTorch 2.6 and above (including any later major version, whatever its minor number)
24+ #define GET_TENSOR_TYPE(x) (x).scalar_type()
25+ #define IS_CUDA_TENSOR(x) (x).device().is_cuda()
26+ #else
27+ // PyTorch older than 2.6
28+ #define GET_TENSOR_TYPE(x) (x).type()
29+ #define IS_CUDA_TENSOR(x) (x).type().is_cuda()
30+ #endif
1831
1932namespace groundingdino {
2033
2134at::Tensor ms_deform_attn_cuda_forward (
22- const at::Tensor &value,
35+ const at::Tensor &value,
2336 const at::Tensor &spatial_shapes,
2437 const at::Tensor &level_start_index,
2538 const at::Tensor &sampling_loc,
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
3245 AT_ASSERTM (sampling_loc.is_contiguous (), " sampling_loc tensor has to be contiguous" );
3346 AT_ASSERTM (attn_weight.is_contiguous (), " attn_weight tensor has to be contiguous" );
3447
35- AT_ASSERTM (value. type (). is_cuda ( ), " value must be a CUDA tensor" );
36- AT_ASSERTM (spatial_shapes. type (). is_cuda ( ), " spatial_shapes must be a CUDA tensor" );
37- AT_ASSERTM (level_start_index. type (). is_cuda ( ), " level_start_index must be a CUDA tensor" );
38- AT_ASSERTM (sampling_loc. type (). is_cuda ( ), " sampling_loc must be a CUDA tensor" );
39- AT_ASSERTM (attn_weight. type (). is_cuda ( ), " attn_weight must be a CUDA tensor" );
48+ AT_ASSERTM (IS_CUDA_TENSOR (value ), " value must be a CUDA tensor" );
49+ AT_ASSERTM (IS_CUDA_TENSOR (spatial_shapes ), " spatial_shapes must be a CUDA tensor" );
50+ AT_ASSERTM (IS_CUDA_TENSOR (level_start_index ), " level_start_index must be a CUDA tensor" );
51+ AT_ASSERTM (IS_CUDA_TENSOR (sampling_loc ), " sampling_loc must be a CUDA tensor" );
52+ AT_ASSERTM (IS_CUDA_TENSOR (attn_weight ), " attn_weight must be a CUDA tensor" );
4053
4154 const int batch = value.size (0 );
4255 const int spatial_size = value.size (1 );
@@ -51,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
5164 const int im2col_step_ = std::min (batch, im2col_step);
5265
5366 AT_ASSERTM (batch % im2col_step_ == 0 , " batch(%d) must divide im2col_step(%d)" , batch, im2col_step_);
54-
67+
5568 auto output = at::zeros ({batch, num_query, num_heads, channels}, value.options ());
5669
5770 const int batch_n = im2col_step_;
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
6275 for (int n = 0 ; n < batch/im2col_step_; ++n)
6376 {
6477 auto columns = output_n.select (0 , n);
65- AT_DISPATCH_FLOATING_TYPES (value. type ( ), " ms_deform_attn_forward_cuda" , ([&] {
78+ AT_DISPATCH_FLOATING_TYPES (GET_TENSOR_TYPE (value ), " ms_deform_attn_forward_cuda" , ([&] {
6679 ms_deformable_im2col_cuda (at::cuda::getCurrentCUDAStream (),
6780 value.data <scalar_t >() + n * im2col_step_ * per_value_size,
6881 spatial_shapes.data <int64_t >(),
@@ -82,7 +95,7 @@ at::Tensor ms_deform_attn_cuda_forward(
8295
8396
8497std::vector<at::Tensor> ms_deform_attn_cuda_backward (
85- const at::Tensor &value,
98+ const at::Tensor &value,
8699 const at::Tensor &spatial_shapes,
87100 const at::Tensor &level_start_index,
88101 const at::Tensor &sampling_loc,
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
98111 AT_ASSERTM (attn_weight.is_contiguous (), " attn_weight tensor has to be contiguous" );
99112 AT_ASSERTM (grad_output.is_contiguous (), " grad_output tensor has to be contiguous" );
100113
101- AT_ASSERTM (value. type (). is_cuda ( ), " value must be a CUDA tensor" );
102- AT_ASSERTM (spatial_shapes. type (). is_cuda ( ), " spatial_shapes must be a CUDA tensor" );
103- AT_ASSERTM (level_start_index. type (). is_cuda ( ), " level_start_index must be a CUDA tensor" );
104- AT_ASSERTM (sampling_loc. type (). is_cuda ( ), " sampling_loc must be a CUDA tensor" );
105- AT_ASSERTM (attn_weight. type (). is_cuda ( ), " attn_weight must be a CUDA tensor" );
106- AT_ASSERTM (grad_output. type (). is_cuda ( ), " grad_output must be a CUDA tensor" );
114+ AT_ASSERTM (IS_CUDA_TENSOR (value ), " value must be a CUDA tensor" );
115+ AT_ASSERTM (IS_CUDA_TENSOR (spatial_shapes ), " spatial_shapes must be a CUDA tensor" );
116+ AT_ASSERTM (IS_CUDA_TENSOR (level_start_index ), " level_start_index must be a CUDA tensor" );
117+ AT_ASSERTM (IS_CUDA_TENSOR (sampling_loc ), " sampling_loc must be a CUDA tensor" );
118+ AT_ASSERTM (IS_CUDA_TENSOR (attn_weight ), " attn_weight must be a CUDA tensor" );
119+ AT_ASSERTM (IS_CUDA_TENSOR (grad_output ), " grad_output must be a CUDA tensor" );
107120
108121 const int batch = value.size (0 );
109122 const int spatial_size = value.size (1 );
@@ -128,11 +141,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
128141 auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2 ;
129142 auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130143 auto grad_output_n = grad_output.view ({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131-
144+
132145 for (int n = 0 ; n < batch/im2col_step_; ++n)
133146 {
134147 auto grad_output_g = grad_output_n.select (0 , n);
135- AT_DISPATCH_FLOATING_TYPES (value. type ( ), " ms_deform_attn_backward_cuda" , ([&] {
148+ AT_DISPATCH_FLOATING_TYPES (GET_TENSOR_TYPE (value ), " ms_deform_attn_backward_cuda" , ([&] {
136149 ms_deformable_col2im_cuda (at::cuda::getCurrentCUDAStream (),
137150 grad_output_g.data <scalar_t >(),
138151 value.data <scalar_t >() + n * im2col_step_ * per_value_size,
@@ -153,4 +166,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
153166 };
154167}
155168
156- } // namespace groundingdino
169+ } // namespace groundingdino
0 commit comments