@@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
     }
     auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifdef USE_ROCM
+    constexpr int grp_sz = 128;
+    launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+      if (unrl) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        void* out0 = data[0] + offsets0[0];
+        void* out1 = data[0] + offsets1[0];
+        void* out2 = data[0] + offsets2[0];
+        void* out3 = data[0] + offsets3[0];
+        arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
+        arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
+        arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
+        arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
+        c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
+        c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
+      } else {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      }
+    });
+#else
     launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
       auto offsets = offset_calc.get(idx);
       void* out = data[0] + offsets[0];
       arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
     });
+#endif
   }
 }

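Note: this hunk calls launch_legacy_kernel_manual_unroll, whose definition is not shown here. Below is a minimal sketch of how such a launcher could be shaped, assuming the same <nt, vt> (threads per block, items per thread) convention as launch_legacy_kernel: the kernel hands the lambda unrl == true once per thread whose whole group of vt elements is in range, and falls back to per-element calls with unrl == false for the tail. The kernel name, grid math, and included headers are illustrative assumptions, not the actual implementation from this PR.

// Sketch only (not the actual implementation): a plausible shape for
// launch_legacy_kernel_manual_unroll, assuming the nt/vt convention used by
// launch_legacy_kernel. The functor receives (idx, unrl): unrl == true means
// the whole group {idx, idx + nt, ..., idx + (vt - 1) * nt} is in range.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

template <int nt, int vt, typename func_t>
__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {  // hypothetical kernel name
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  if (idx + nt * (vt - 1) < N) {
    // Every element of this thread's group is in range: take the unrolled path once.
    f(idx, true);
  } else {
    // Tail: process the remaining in-range elements one at a time.
    #pragma unroll
    for (int i = 0; i < vt; i++) {
      if (idx < N) {
        f(idx, false);
      }
      idx += nt;
    }
  }
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel_manual_unroll<nt, vt, func_t>
      <<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

With nt = 128 and vt = 4 this matches the indices used in the unrolled branch above (idx, idx + grp_sz, idx + grp_sz * 2, idx + grp_sz * 3), so each thread in a fully in-range block does four offset calculations and stores back-to-back, which is the point of the manual unroll on ROCm.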