@@ -65,101 +65,59 @@ ur_result_t urCalculateNumChannels(ur_image_channel_order_t order,
6565// / format if not nullptr.
6666// / /param return_pixel_size_bytes will be set to the pixel
6767// / byte size if not nullptr.
68+ // / /param return_normalized_dtype_flag will be set if the
69+ // / data type is normalized if not nullptr.
6870ur_result_t
6971urToCudaImageChannelFormat (ur_image_channel_type_t image_channel_type,
7072 ur_image_channel_order_t image_channel_order,
7173 CUarray_format *return_cuda_format,
72- size_t *return_pixel_size_bytes) {
74+ size_t *return_pixel_size_bytes,
75+ unsigned int *return_normalized_dtype_flag) {
7376
74- CUarray_format cuda_format;
77+ CUarray_format cuda_format = CU_AD_FORMAT_UNSIGNED_INT8 ;
7578 size_t pixel_size_bytes = 0 ;
7679 unsigned int num_channels = 0 ;
80+ unsigned int normalized_dtype_flag = 0 ;
7781 UR_CHECK_ERROR (urCalculateNumChannels (image_channel_order, &num_channels));
7882
7983 switch (image_channel_type) {
80- #define CASE (FROM, TO, SIZE ) \
84+ #define CASE (FROM, TO, SIZE, NORM ) \
8185 case FROM: { \
8286 cuda_format = TO; \
8387 pixel_size_bytes = SIZE * num_channels; \
88+ normalized_dtype_flag = NORM; \
8489 break ; \
8590 }
8691
87- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 )
88- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 )
89- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 )
90- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 )
91- CASE (UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2 )
92- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4 )
93- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4 )
94- CASE (UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4 )
92+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 , 0 )
93+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 , 0 )
94+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 ,
95+ 0 )
96+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 , 0 )
97+ CASE (UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2 , 0 )
98+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4 ,
99+ 0 )
100+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4 , 0 )
101+ CASE (UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4 , 0 )
102+ CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 , 1 )
103+ CASE (UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 , 1 )
104+ CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 , 1 )
105+ CASE (UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 , 1 )
95106
96107#undef CASE
97108 default :
98109 break ;
99110 }
100111
101- // These new formats were brought in in CUDA 11.5
102- #if CUDA_VERSION >= 11050
103-
104- // If none of the above channel types were passed, check those below
105- if (pixel_size_bytes == 0 ) {
106-
107- // We can't use a switch statement here because these single
108- // UR_IMAGE_CHANNEL_TYPEs can correspond to multiple [u/s]norm CU_AD_FORMATs
109- // depending on the number of channels. We use a std::map instead to
110- // retrieve the correct CUDA format
111-
112- // map < <channel type, num channels> , <CUDA format, data type byte size> >
113- const std::map<std::pair<ur_image_channel_type_t , uint32_t >,
114- std::pair<CUarray_format, uint32_t >>
115- norm_channel_type_map{
116- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 1 },
117- {CU_AD_FORMAT_UNORM_INT8X1, 1 }},
118- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 2 },
119- {CU_AD_FORMAT_UNORM_INT8X2, 2 }},
120- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 4 },
121- {CU_AD_FORMAT_UNORM_INT8X4, 4 }},
122-
123- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 1 },
124- {CU_AD_FORMAT_SNORM_INT8X1, 1 }},
125- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 2 },
126- {CU_AD_FORMAT_SNORM_INT8X2, 2 }},
127- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 4 },
128- {CU_AD_FORMAT_SNORM_INT8X4, 4 }},
129-
130- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 1 },
131- {CU_AD_FORMAT_UNORM_INT16X1, 2 }},
132- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 2 },
133- {CU_AD_FORMAT_UNORM_INT16X2, 4 }},
134- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 4 },
135- {CU_AD_FORMAT_UNORM_INT16X4, 8 }},
136-
137- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 1 },
138- {CU_AD_FORMAT_SNORM_INT16X1, 2 }},
139- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 2 },
140- {CU_AD_FORMAT_SNORM_INT16X2, 4 }},
141- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 4 },
142- {CU_AD_FORMAT_SNORM_INT16X4, 8 }},
143- };
144-
145- try {
146- auto cuda_format_and_size = norm_channel_type_map.at (
147- std::make_pair (image_channel_type, num_channels));
148- cuda_format = cuda_format_and_size.first ;
149- pixel_size_bytes = cuda_format_and_size.second ;
150- } catch (const std::out_of_range &) {
151- return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
152- }
153- }
154-
155- #endif
156-
157112 if (return_cuda_format) {
158113 *return_cuda_format = cuda_format;
159114 }
160115 if (return_pixel_size_bytes) {
161116 *return_pixel_size_bytes = pixel_size_bytes;
162117 }
118+ if (return_normalized_dtype_flag) {
119+ *return_normalized_dtype_flag = normalized_dtype_flag;
120+ }
163121 return UR_RESULT_SUCCESS;
164122}
165123
@@ -189,53 +147,17 @@ cudaToUrImageChannelFormat(CUarray_format cuda_format,
189147 UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT);
190148 CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_FLOAT,
191149 UR_IMAGE_CHANNEL_TYPE_FLOAT);
192- #if CUDA_VERSION >= 11050
193-
194- // Note that the CUDA UNORM and SNORM formats also encode the number of
195- // channels.
196- // Since UR does not encode this, we map different CUDA formats to the same
197- // UR channel type.
198- // Since this function is only called from `urBindlessImagesImageGetInfoExp`
199- // which has access to `CUDA_ARRAY3D_DESCRIPTOR`, we can determine the
200- // number of channels in the calling function.
201-
202- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X1,
203- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
204- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X2,
205- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
206- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X4,
207- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
208-
209- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X1,
210- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
211- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X2,
212- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
213- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X4,
214- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
215-
216- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X1,
217- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
218- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X2,
219- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
220- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X4,
221- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
222-
223- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X1,
224- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
225- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X2,
226- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
227- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X4,
228- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
229- #endif
230- #undef MAP
231150 default :
151+ // Default invalid enum
152+ *return_image_channel_type = UR_IMAGE_CHANNEL_TYPE_FORCE_UINT32;
232153 return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
233154 }
234155}
235156
236157ur_result_t urTextureCreate (ur_sampler_handle_t hSampler,
237158 const ur_image_desc_t *pImageDesc,
238159 const CUDA_RESOURCE_DESC &ResourceDesc,
160+ const unsigned int normalized_dtype_flag,
239161 ur_exp_image_native_handle_t *phRetImage) {
240162
241163 try {
@@ -306,8 +228,9 @@ ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
306228
307229 // CUDA default promotes 8-bit and 16-bit integers to float between [0,1]
308230 // This flag prevents this behaviour.
309- ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
310-
231+ if (!normalized_dtype_flag) {
232+ ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
233+ }
311234 // Cubemap attributes
312235 ur_exp_sampler_cubemap_filter_mode_t CubemapFilterModeProp =
313236 hSampler->getCubemapFilterMode ();
@@ -413,9 +336,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageAllocateExp(
413336 UR_CHECK_ERROR (urCalculateNumChannels (pImageFormat->channelOrder ,
414337 &array_desc.NumChannels ));
415338
416- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat-> channelType ,
417- pImageFormat->channelOrder ,
418- &array_desc. Format , nullptr ));
339+ UR_CHECK_ERROR (urToCudaImageChannelFormat (
340+ pImageFormat-> channelType , pImageFormat->channelOrder , &array_desc. Format ,
341+ nullptr , nullptr ));
419342
420343 array_desc.Flags = 0 ; // No flags required
421344 array_desc.Width = pImageDesc->width ;
@@ -534,7 +457,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
534457 size_t PixelSizeBytes;
535458 UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
536459 pImageFormat->channelOrder , &format,
537- &PixelSizeBytes));
460+ &PixelSizeBytes, nullptr ));
538461
539462 try {
540463
@@ -579,9 +502,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
579502
580503 CUarray_format format;
581504 size_t PixelSizeBytes;
582- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
583- pImageFormat->channelOrder , &format,
584- &PixelSizeBytes));
505+ unsigned int normalized_dtype_flag;
506+ UR_CHECK_ERROR (urToCudaImageChannelFormat (
507+ pImageFormat->channelType , pImageFormat->channelOrder , &format,
508+ &PixelSizeBytes, &normalized_dtype_flag));
585509
586510 try {
587511 CUDA_RESOURCE_DESC image_res_desc = {};
@@ -630,8 +554,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
630554 return UR_RESULT_ERROR_INVALID_VALUE;
631555 }
632556
633- UR_CHECK_ERROR (
634- urTextureCreate (hSampler, pImageDesc, image_res_desc , phImage));
557+ UR_CHECK_ERROR (urTextureCreate (hSampler, pImageDesc, image_res_desc,
558+ normalized_dtype_flag , phImage));
635559
636560 } catch (ur_result_t Err) {
637561 return Err;
@@ -671,7 +595,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
671595 // later.
672596 UR_CHECK_ERROR (urToCudaImageChannelFormat (pSrcImageFormat->channelType ,
673597 pSrcImageFormat->channelOrder ,
674- nullptr , &PixelSizeBytes));
598+ nullptr , &PixelSizeBytes, nullptr ));
675599
676600 try {
677601 ScopedContext Active (hQueue->getDevice ());
@@ -1150,8 +1074,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
11501074 urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
11511075
11521076 CUarray_format format;
1153- UR_CHECK_ERROR (urToCudaImageChannelFormat (
1154- pImageFormat->channelType , pImageFormat->channelOrder , &format, nullptr ));
1077+ UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
1078+ pImageFormat->channelOrder , &format,
1079+ nullptr , nullptr ));
11551080
11561081 try {
11571082 ScopedContext Active (hDevice);
0 commit comments