Skip to content

Commit c00ef4b

Browse files
committed
issue/791 - fix add_rmsnorm api on mtx and mth
1 parent 264d349 commit c00ef4b

File tree

4 files changed

+48
-21
lines changed

4 files changed

+48
-21
lines changed

src/infiniop/devices/metax/metax_kernel_common.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
// Posible maximum number of threads per block for METAX architectures
1010
// Used for picking correct kernel launch configuration
11-
#define METAX_BLOCK_SIZE_1024 1024
1211
#define METAX_BLOCK_SIZE_512 512
12+
#define METAX_BLOCK_SIZE_1024 1024
13+
#define METAX_BLOCK_SIZE_2048 2048
14+
#define METAX_BLOCK_SIZE_4096 4096
1315

1416
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
1517

src/infiniop/devices/moore/moore_kernel_common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
// Posible maximum number of threads per block for MUSA architectures
88
// Used for picking correct kernel launch configuration
9+
#define MOORE_BLOCK_SIZE_4096 4096
910
#define MOORE_BLOCK_SIZE_2048 2048
1011
#define MOORE_BLOCK_SIZE_1024 1024
1112
#define MOORE_BLOCK_SIZE_512 512

src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
5353
infiniopHandle_t handle,
5454
Descriptor **desc_ptr,
5555
infiniopTensorDescriptor_t y_desc,
56+
infiniopTensorDescriptor_t residual_out_desc,
5657
infiniopTensorDescriptor_t a_desc,
5758
infiniopTensorDescriptor_t b_desc,
5859
infiniopTensorDescriptor_t weight_desc,
59-
float epsilon,
60-
infiniopTensorDescriptor_t residual_out_desc) {
61-
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
60+
float epsilon) {
61+
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
6262
CHECK_RESULT(result);
6363
auto info = result.take();
6464

@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
104104
// Handle different data type combinations following Metax pattern
105105
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
106106
LAUNCH_KERNEL(half, half, float);
107-
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
108-
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
109-
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
110-
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
111-
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
112-
LAUNCH_KERNEL(half, float, float);
113107
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
114108
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
109+
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
110+
LAUNCH_KERNEL(half, float, float);
111+
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
112+
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
115113
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
116114
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
115+
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
116+
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
117117
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
118118
LAUNCH_KERNEL(float, float, float);
119119
} else {
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
128128
// Main calculation function
129129
infiniStatus_t Descriptor::calculate(
130130
void *workspace, size_t workspace_size,
131-
void *y, const void *a, const void *b, const void *weight,
132-
void *residual_out, void *stream_) const {
131+
void *y, void *residual_out, const void *a, const void *b, const void *weight,
132+
void *stream) const {
133133

134134
// Check workspace size
135135
if (workspace_size < _workspace_size) {
@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
148148
auto dim = _info.dim();
149149
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
150150
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
151-
auto stream = reinterpret_cast<hcStream_t>(stream_);
151+
auto stream_ = reinterpret_cast<hcStream_t>(stream);
152152

153-
// Launch kernel with appropriate block size based on device capability
154-
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
153+
// Launch kernel with different block sizes
154+
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
155+
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
156+
batch_size, nhead, dim,
157+
y, _info.atype, stride_y_batch, stride_y_nhead,
158+
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
159+
a, stride_a_batch, stride_a_nhead,
160+
b, stride_b_batch, stride_b_nhead,
161+
weight, _info.wtype, _info.epsilon, stream_));
162+
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
155163
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
156164
batch_size, nhead, dim,
157165
y, _info.atype, stride_y_batch, stride_y_nhead,
158166
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
159167
a, stride_a_batch, stride_a_nhead,
160168
b, stride_b_batch, stride_b_nhead,
161-
weight, _info.wtype, _info.epsilon, stream));
169+
weight, _info.wtype, _info.epsilon, stream_));
170+
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
171+
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
172+
batch_size, nhead, dim,
173+
y, _info.atype, stride_y_batch, stride_y_nhead,
174+
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
175+
a, stride_a_batch, stride_a_nhead,
176+
b, stride_b_batch, stride_b_nhead,
177+
weight, _info.wtype, _info.epsilon, stream_));
178+
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
179+
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
180+
batch_size, nhead, dim,
181+
y, _info.atype, stride_y_batch, stride_y_nhead,
182+
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
183+
a, stride_a_batch, stride_a_nhead,
184+
b, stride_b_batch, stride_b_nhead,
185+
weight, _info.wtype, _info.epsilon, stream_));
162186
} else {
163187
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
164188
}

src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
5353
infiniopHandle_t handle,
5454
Descriptor **desc_ptr,
5555
infiniopTensorDescriptor_t y_desc,
56+
infiniopTensorDescriptor_t residual_out_desc,
5657
infiniopTensorDescriptor_t a_desc,
5758
infiniopTensorDescriptor_t b_desc,
5859
infiniopTensorDescriptor_t weight_desc,
59-
float epsilon,
60-
infiniopTensorDescriptor_t residual_out_desc) {
61-
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
60+
float epsilon) {
61+
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
6262
CHECK_RESULT(result);
6363
auto info = result.take();
6464

@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
128128
// Main calculation function
129129
infiniStatus_t Descriptor::calculate(
130130
void *workspace, size_t workspace_size,
131-
void *y, const void *a, const void *b, const void *weight,
132-
void *residual_out, void *stream) const {
131+
void *y, void *residual_out, const void *a, const void *b, const void *weight,
132+
void *stream) const {
133133

134134
// Check workspace size
135135
if (workspace_size < _workspace_size) {

0 commit comments

Comments
 (0)