@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
4242 });
4343}
4444
#ifdef USE_ROCM
// ROCm-only specializations for copying a reduced-precision source tensor
// (BFloat16 / Half) into a float32 destination. gpu_kernel_nocast is used so
// the widening conversion happens inside the lambda itself rather than via
// the generic dtype-cast wrapper (NOTE(review): presumably avoids the extra
// cast machinery — confirm against gpu_kernel_nocast in Loops.cuh).

// Element-wise BFloat16 -> float copy kernel.
void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) -> float {
    const float widened = static_cast<float>(value);
    return widened;
  });
}

// Element-wise Half (fp16) -> float copy kernel.
void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) -> float {
    const float widened = static_cast<float>(value);
    return widened;
  });
}
#endif
57+
4558void float8_copy_kernel_cuda (TensorIteratorBase &iter) {
4659 ScalarType dtype = iter.dtype (0 );
4760 ScalarType other_dtype = iter.dtype (1 );
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
187200 } else {
188201 float16_copy_kernel_cuda (iter);
189202 }
190- } else if (isBitsType (dtype)) {
203+ }
204+ #ifdef USE_ROCM
205+ else if ((iter.dtype (1 ) == kBFloat16 || iter.dtype (1 ) == kHalf ) && dtype == kFloat ) {
206+ if (iter.dtype (1 ) == kBFloat16 ) {
207+ bfloat16tofloat32_copy_kernel_cuda (iter);
208+ } else {
209+ float16tofloat32_copy_kernel_cuda (iter);
210+ }
211+ }
212+ #endif
213+ else if (isBitsType (dtype)) {
191214 TORCH_CHECK (dtype == iter.dtype (1 ), " copy_() does not support casting "
192215 " bits types to different bits types. Source dtype is " , iter.dtype (1 ), " target dtype is " , dtype);
193216 AT_DISPATCH_BIT_TYPES (dtype, " copy_" , [&] {
0 commit comments