@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
5353 infiniopHandle_t handle,
5454 Descriptor **desc_ptr,
5555 infiniopTensorDescriptor_t y_desc,
56+ infiniopTensorDescriptor_t residual_out_desc,
5657 infiniopTensorDescriptor_t a_desc,
5758 infiniopTensorDescriptor_t b_desc,
5859 infiniopTensorDescriptor_t weight_desc,
59- float epsilon,
60- infiniopTensorDescriptor_t residual_out_desc) {
61- auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
60+ float epsilon) {
61+ auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
6262 CHECK_RESULT(result);
6363 auto info = result.take();
6464
@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
104104 // Handle different data type combinations following Metax pattern
105105 if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
106106 LAUNCH_KERNEL(half, half, float);
107- } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
108- LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
109- } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
110- LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
111- } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
112- LAUNCH_KERNEL(half, float, float);
113107 } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
114108 LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
109+ } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
110+ LAUNCH_KERNEL(half, float, float);
111+ } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
112+ LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
115113 } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
116114 LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
115+ } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
116+ LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
117117 } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
118118 LAUNCH_KERNEL(float, float, float);
119119 } else {
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
128128// Main calculation function
129129infiniStatus_t Descriptor::calculate(
130130 void *workspace, size_t workspace_size,
131- void *y, const void *a, const void *b, const void *weight,
132- void *residual_out, void *stream_ ) const {
131+ void *y, void *residual_out, const void *a, const void *b, const void *weight,
132+ void *stream ) const {
133133
134134 // Check workspace size
135135 if (workspace_size < _workspace_size) {
@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
148148 auto dim = _info.dim();
149149 uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
150150 size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
151- auto stream = reinterpret_cast<hcStream_t>(stream_ );
151+ auto stream_ = reinterpret_cast<hcStream_t>(stream );
152152
153- // Launch kernel with appropriate block size based on device capability
154- if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
153+ // Launch kernel with different block sizes
154+ if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
155+ CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
156+ batch_size, nhead, dim,
157+ y, _info.atype, stride_y_batch, stride_y_nhead,
158+ residual_out, stride_residual_out_batch, stride_residual_out_nhead,
159+ a, stride_a_batch, stride_a_nhead,
160+ b, stride_b_batch, stride_b_nhead,
161+ weight, _info.wtype, _info.epsilon, stream_));
162+ } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
155163 CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
156164 batch_size, nhead, dim,
157165 y, _info.atype, stride_y_batch, stride_y_nhead,
158166 residual_out, stride_residual_out_batch, stride_residual_out_nhead,
159167 a, stride_a_batch, stride_a_nhead,
160168 b, stride_b_batch, stride_b_nhead,
161- weight, _info.wtype, _info.epsilon, stream));
169+ weight, _info.wtype, _info.epsilon, stream_));
170+ } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
171+ CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
172+ batch_size, nhead, dim,
173+ y, _info.atype, stride_y_batch, stride_y_nhead,
174+ residual_out, stride_residual_out_batch, stride_residual_out_nhead,
175+ a, stride_a_batch, stride_a_nhead,
176+ b, stride_b_batch, stride_b_nhead,
177+ weight, _info.wtype, _info.epsilon, stream_));
178+ } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
179+ CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
180+ batch_size, nhead, dim,
181+ y, _info.atype, stride_y_batch, stride_y_nhead,
182+ residual_out, stride_residual_out_batch, stride_residual_out_nhead,
183+ a, stride_a_batch, stride_a_nhead,
184+ b, stride_b_batch, stride_b_nhead,
185+ weight, _info.wtype, _info.epsilon, stream_));
162186 } else {
163187 return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
164188 }
0 commit comments