@@ -994,12 +994,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
     }
     auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifdef USE_ROCM
+    constexpr int grp_sz = 128;
+    launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+      if (unrl) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        void* out0 = data[0] + offsets0[0];
+        void* out1 = data[0] + offsets1[0];
+        void* out2 = data[0] + offsets2[0];
+        void* out3 = data[0] + offsets3[0];
+        arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
+        arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
+        arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
+        arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
+        c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
+        c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
+      } else {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      }
+    });
+#else
     launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
       auto offsets = offset_calc.get(idx);
       void* out = data[0] + offsets[0];
       arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
     });
+#endif
   }
 }
 
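For reference, the lambda in the ROCm branch relies on the launcher passing unrl = true only when a full group of four elements (spaced grp_sz apart) is in bounds. The sketch below shows one way a launcher with that contract could be wired up; it is an illustrative assumption, not the definition from this commit. The kernel name legacy_kernel_manual_unroll, the raw <<<...>>> launch, and the omission of PyTorch's stream selection and launch-error checking are all simplifications.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical device kernel: each thread owns up to `unroll` elements spaced
// `block_size` apart. When every one of idx, idx + block_size, ...,
// idx + (unroll - 1) * block_size is in range, the functor is called once with
// unrl = true so it can process the whole group without per-element bounds
// checks; otherwise the tail is handled one bounds-checked element at a time.
template <int block_size, int unroll, typename func_t>
__global__ void legacy_kernel_manual_unroll(int numel, func_t f) {
  int idx = block_size * unroll * blockIdx.x + threadIdx.x;
  if (idx + (unroll - 1) * block_size < numel) {
    f(idx, true);
  } else {
#pragma unroll
    for (int i = 0; i < unroll; ++i) {
      if (idx < numel) {
        f(idx, false);
      }
      idx += block_size;
    }
  }
}

// Hypothetical host wrapper matching the call site above: one block_size-wide
// block covers block_size * unroll elements.
template <int block_size, int unroll, typename func_t>
static void launch_legacy_kernel_manual_unroll(int64_t numel, const func_t& f) {
  if (numel == 0) {
    return;
  }
  const int64_t elems_per_block = static_cast<int64_t>(block_size) * unroll;
  const auto grid = static_cast<unsigned int>((numel + elems_per_block - 1) / elems_per_block);
  legacy_kernel_manual_unroll<block_size, unroll>
      <<<grid, block_size>>>(static_cast<int>(numel), f);
  // Stream selection and launch-error checking are intentionally omitted here.
}

Under that contract, the unrl flag lets the hot path in the diff compute four offsets and stores per thread without four separate bounds checks, while the same lambda still handles the ragged tail one element at a time.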