issue/884 - add_rms_norm on iluvatar, metax and moore #898
Open

wooway777 wants to merge 1 commit into `main` from `issue/884`
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh (new file, 8 additions)

```cpp
#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__

#include "../add_rms_norm.h"

DESCRIPTOR(metax)

#endif
```
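`DESCRIPTOR(metax)` is a declaration macro defined in `../add_rms_norm.h`, a file this diff does not touch. As a rough, hypothetical sketch (inferred only from the members the `.maca` implementation below relies on: `Opaque`, `_opaque`, `_info`, `_workspace_size`, and the five-argument constructor call), it plausibly expands to something like:

```cpp
/* Hypothetical sketch of DESCRIPTOR(NS); the real macro lives in
   ../add_rms_norm.h and is not shown in this PR. */
#define DESCRIPTOR(NS)                                                     \
    namespace op::add_rms_norm::NS {                                       \
    class Descriptor final : public InfiniopDescriptor {                   \
        struct Opaque;              /* device-specific state, defined */   \
        Opaque *_opaque;            /* in the per-device source file */    \
        AddRMSNormInfo _info;       /* validated shapes/strides/dtypes */  \
        size_t _workspace_size;                                            \
        Descriptor(Opaque *opaque, AddRMSNormInfo info, size_t ws,         \
                   infiniDevice_t device, int device_id)                   \
            : InfiniopDescriptor{device, device_id}, _opaque(opaque),      \
              _info(std::move(info)), _workspace_size(ws) {}               \
                                                                           \
    public:                                                                \
        ~Descriptor();                                                     \
        static infiniStatus_t create(                                      \
            infiniopHandle_t handle, Descriptor **desc_ptr,                \
            infiniopTensorDescriptor_t y_desc,                             \
            infiniopTensorDescriptor_t a_desc,                             \
            infiniopTensorDescriptor_t b_desc,                             \
            infiniopTensorDescriptor_t weight_desc,                        \
            float epsilon,                                                 \
            infiniopTensorDescriptor_t residual_out_desc);                 \
        infiniStatus_t calculate(                                          \
            void *workspace, size_t workspace_size,                        \
            void *y, const void *a, const void *b, const void *weight,     \
            void *residual_out, void *stream) const;                       \
    };                                                                     \
    }
```

The same macro is reused below as `DESCRIPTOR(moore)` to stamp out the identical interface in `op::add_rms_norm::moore`.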
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca (new file, 167 additions)

```cpp
#include "../../../devices/metax/metax_common.h"
#include "add_rms_norm_metax.cuh"

#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"

#include "../cuda/kernel.cuh"

// Kernel function template for add_rms_norm on Metax platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}

namespace op::add_rms_norm::metax {

// Internal opaque structure for Metax device handle
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Destructor
Descriptor::~Descriptor() {
    delete _opaque;
}

// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    hcStream_t stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                             \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
        reinterpret_cast<Tdata *>(y),                                                                       \
        reinterpret_cast<Tdata *>(residual_out),                                                            \
        stride_y_batch,                                                                                     \
        stride_y_nhead,                                                                                     \
        stride_residual_out_batch,                                                                          \
        stride_residual_out_nhead,                                                                          \
        reinterpret_cast<const Tdata *>(a),                                                                 \
        stride_a_batch,                                                                                     \
        stride_a_nhead,                                                                                     \
        reinterpret_cast<const Tdata *>(b),                                                                 \
        stride_b_batch,                                                                                     \
        stride_b_nhead,                                                                                     \
        reinterpret_cast<const Tweight *>(w),                                                               \
        nhead,                                                                                              \
        dim,                                                                                                \
        epsilon)

    // Handle different data type combinations following Metax pattern
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}

// Main calculation function
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream_) const {

    // Check workspace size
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Extract tensor strides and dimensions
    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto stream = reinterpret_cast<hcStream_t>(stream_);

    // Launch kernel with appropriate block size based on device capability
    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::metax
```
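The per-row arithmetic itself lives in `add_rmsnormBlock`, pulled in from `../cuda/kernel.cuh`, which is not part of this diff. Assuming it implements the conventional fused residual-add plus RMSNorm, each (batch, head) row of length $d = \text{dim}$ would be computed as

$$r = a + b, \qquad y = \frac{r}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} r_i^2 + \epsilon}} \odot w, \qquad \text{residual\_out} = r,$$

with the reduction carried out in `Tcompute` (always `float` in this PR) regardless of the storage types `Tdata` and `Tweight`.

For context, here is a minimal host-side sketch of driving this descriptor. The `create` and `calculate` signatures are the ones defined above, but the handle, the five tensor descriptors, and the device buffers (`y`, `a`, `b`, `weight`, `residual_out`, `stream`) are assumed to have been set up elsewhere:

```cpp
// Sketch only: assumes `handle`, the tensor descriptors, and the device
// buffers/stream already exist and match the shapes the descriptors describe.
op::add_rms_norm::metax::Descriptor *desc = nullptr;
CHECK_STATUS(op::add_rms_norm::metax::Descriptor::create(
    handle, &desc,
    y_desc, a_desc, b_desc, weight_desc,
    /*epsilon=*/1e-6f, residual_out_desc));

// The descriptor is created with a workspace size of 0 (see the `0` passed
// to the Descriptor constructor above), so no workspace buffer is needed.
CHECK_STATUS(desc->calculate(
    /*workspace=*/nullptr, /*workspace_size=*/0,
    y, a, b, weight, residual_out, stream));

delete desc;
```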
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h (new file, 8 additions)

```cpp
#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__

#include "../add_rms_norm.h"

DESCRIPTOR(moore)

#endif
```
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu (new file, 183 additions)

```cpp
#include "../../../devices/moore/moore_common.h"
#include "add_rms_norm_moore.h"

#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"

#include "../cuda/kernel.cuh"

// Kernel function template for add_rms_norm on Moore platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}

namespace op::add_rms_norm::moore {

// Internal opaque structure for Moore device handle
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Destructor
Descriptor::~Descriptor() {
    delete _opaque;
}

// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    musaStream_t musa_stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                                  \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
        reinterpret_cast<Tdata *>(y),                                                                            \
        reinterpret_cast<Tdata *>(residual_out),                                                                 \
        stride_y_batch,                                                                                          \
        stride_y_nhead,                                                                                          \
        stride_residual_out_batch,                                                                               \
        stride_residual_out_nhead,                                                                               \
        reinterpret_cast<const Tdata *>(a),                                                                      \
        stride_a_batch,                                                                                          \
        stride_a_nhead,                                                                                          \
        reinterpret_cast<const Tdata *>(b),                                                                      \
        stride_b_batch,                                                                                          \
        stride_b_nhead,                                                                                          \
        reinterpret_cast<const Tweight *>(w),                                                                    \
        nhead,                                                                                                   \
        dim,                                                                                                     \
        epsilon)

    // Handle different data type combinations
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__mt_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__mt_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}

// Main calculation function
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream) const {

    // Check workspace size
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Extract tensor strides and dimensions
    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto musa_stream = reinterpret_cast<musaStream_t>(stream);

    // Launch kernel with appropriate block size based on device capability
    if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::moore
```
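Unlike the Metax path, which only handles a 1024-thread block, the Moore `calculate` dispatches on `maxThreadsPerBlock()` across 512-, 1024-, and 2048-thread variants. In every case the launch geometry is the same: one thread block per (batch, head) row, with `BLOCK_SIZE` threads cooperatively reducing over the last dimension. For concreteness, this is what `LAUNCH_KERNEL(half, half, float)` unrolls to inside the Moore `launchKernel` (a direct expansion of the macro above, nothing new added):

```cpp
// Expansion of LAUNCH_KERNEL(half, half, float) in the Moore path:
// grid = batch_size * nhead blocks (one per row), BLOCK_SIZE threads each,
// no dynamic shared memory, enqueued on musa_stream.
add_rmsnormKernel<BLOCK_SIZE, float, half, half>
    <<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>(
        reinterpret_cast<half *>(y),
        reinterpret_cast<half *>(residual_out),
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        reinterpret_cast<const half *>(a),
        stride_a_batch, stride_a_nhead,
        reinterpret_cast<const half *>(b),
        stride_b_batch, stride_b_nhead,
        reinterpret_cast<const half *>(w),
        nhead, dim, epsilon);
```

Keeping the dispatch in a macro avoids repeating this 17-argument launch for each of the seven supported (atype, wtype) combinations.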