-
Notifications
You must be signed in to change notification settings - Fork 22
GEMM reference computation offload #392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 2 commits
ad748da
11e090b
9006224
3ecea7f
cafee59
86fbbac
54de3db
311ddfe
306e432
445e64f
462945f
e32fb3d
7bf8adb
e11e400
325ece6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -51,11 +51,248 @@ using TShape = std::vector<size_t>; | |||||||||||||||||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float ref_gelu(float x){ | ||||||||||||||||||||||||||||||||||||||
| __device__ __host__ __forceinline__ float ref_gelu(float x){ | ||||||||||||||||||||||||||||||||||||||
| float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); | ||||||||||||||||||||||||||||||||||||||
| return x * cdf; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, | ||||||||||||||||||||||||||||||||||||||
| typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
| __global__ void compute_ref_kernel( | ||||||||||||||||||||||||||||||||||||||
| const A_Type* __restrict__ a_data, | ||||||||||||||||||||||||||||||||||||||
| const B_Type* __restrict__ b_data, | ||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_scalar, // used when mxfp8 == false | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* __restrict__ b_scale_inv_mxfp8, | ||||||||||||||||||||||||||||||||||||||
| const Bias_Type* __restrict__ bias_data, | ||||||||||||||||||||||||||||||||||||||
| float d_scale, | ||||||||||||||||||||||||||||||||||||||
| size_t m, size_t k, size_t n, | ||||||||||||||||||||||||||||||||||||||
| D_Type* __restrict__ d_data, | ||||||||||||||||||||||||||||||||||||||
| float* __restrict__ d_amax, | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* __restrict__ gelu_data, | ||||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb, | ||||||||||||||||||||||||||||||||||||||
| bool is_fp8_output) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||
| const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const bool in_range = (ii < m) && (jj < n); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float val = 0.0f; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (in_range) { | ||||||||||||||||||||||||||||||||||||||
| for (size_t kk = 0; kk < k; ++kk) { | ||||||||||||||||||||||||||||||||||||||
| const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); | ||||||||||||||||||||||||||||||||||||||
| const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_val = a_scale_inv_scalar; | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_val = b_scale_inv_scalar; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (a_scale_inv_mxfp8) { | ||||||||||||||||||||||||||||||||||||||
| const size_t a_scale_idx = | ||||||||||||||||||||||||||||||||||||||
| transa ? (a_idx / 32) : ((kk / 32) * m + ii); | ||||||||||||||||||||||||||||||||||||||
| const size_t b_scale_idx = | ||||||||||||||||||||||||||||||||||||||
| transb ? ((kk / 32) * n + jj) : (b_idx / 32); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float a_byte = static_cast<float>(a_scale_inv_mxfp8[a_scale_idx]); | ||||||||||||||||||||||||||||||||||||||
| const float b_byte = static_cast<float>(b_scale_inv_mxfp8[b_scale_idx]); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| a_scale_inv_val = exp2f(a_byte - 127.0f); | ||||||||||||||||||||||||||||||||||||||
| b_scale_inv_val = exp2f(b_byte - 127.0f); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float a_val = static_cast<float>(a_data[a_idx]); | ||||||||||||||||||||||||||||||||||||||
| const float b_val = static_cast<float>(b_data[b_idx]); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (bias_data) { | ||||||||||||||||||||||||||||||||||||||
| val += static_cast<float>(bias_data[ii]); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (gelu_data) { | ||||||||||||||||||||||||||||||||||||||
| gelu_data[ii + jj * m] = static_cast<Gelu_Type>(val); | ||||||||||||||||||||||||||||||||||||||
| val = ref_gelu(val); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float scaled = val * d_scale; | ||||||||||||||||||||||||||||||||||||||
| d_data[ii + jj * m] = static_cast<D_Type>(scaled); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Blockwise reduction for amax | ||||||||||||||||||||||||||||||||||||||
| if (is_fp8_output && d_amax) { | ||||||||||||||||||||||||||||||||||||||
| const int tid = threadIdx.y * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||
| const int nthreads = blockDim.x * blockDim.y; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| extern __shared__ float s_amax[]; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Out-of-range threads contribute 0 | ||||||||||||||||||||||||||||||||||||||
| s_amax[tid] = in_range ? fabsf(val) : 0.0f; | ||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for (int offset = nthreads / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||||||||||||||||
| if (tid < offset) { | ||||||||||||||||||||||||||||||||||||||
| s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (tid == 0) { | ||||||||||||||||||||||||||||||||||||||
| const float block_max = s_amax[0]; | ||||||||||||||||||||||||||||||||||||||
| atomicMax(d_amax, block_max); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Common implementation used by both tensor-wise and MXFP8 frontends | ||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, | ||||||||||||||||||||||||||||||||||||||
| typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
| static void compute_ref_impl( | ||||||||||||||||||||||||||||||||||||||
| const A_Type* a_data, | ||||||||||||||||||||||||||||||||||||||
| const B_Type* b_data, | ||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_scalar, // used when mxfp8 == false | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* b_scale_inv_mxfp8, | ||||||||||||||||||||||||||||||||||||||
| const Bias_Type* bias_data, | ||||||||||||||||||||||||||||||||||||||
| float d_scale, | ||||||||||||||||||||||||||||||||||||||
| size_t m, size_t k, size_t n, | ||||||||||||||||||||||||||||||||||||||
| D_Type* d_data, | ||||||||||||||||||||||||||||||||||||||
| float* d_amax_host, | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* gelu_data, | ||||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| using transformer_engine::DType; | ||||||||||||||||||||||||||||||||||||||
| using ::TypeInfo; | ||||||||||||||||||||||||||||||||||||||
| using ::isFp8Type; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const DType dtype = TypeInfo<D_Type>::dtype; | ||||||||||||||||||||||||||||||||||||||
| const bool is_fp8_output = isFp8Type(dtype); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenA = m * k; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenB = k * n; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenD = m * n; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenBias = m; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenGelu = m * n; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; | ||||||||||||||||||||||||||||||||||||||
wangye805 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| A_Type* dA = nullptr; | ||||||||||||||||||||||||||||||||||||||
| B_Type* dB = nullptr; | ||||||||||||||||||||||||||||||||||||||
| Bias_Type* dBias = nullptr; | ||||||||||||||||||||||||||||||||||||||
| D_Type* dD = nullptr; | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* dGelu = nullptr; | ||||||||||||||||||||||||||||||||||||||
| float* dAmax = nullptr; | ||||||||||||||||||||||||||||||||||||||
| fp8e8m0* dA_scale = nullptr; | ||||||||||||||||||||||||||||||||||||||
| fp8e8m0* dB_scale = nullptr; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Allocations and H2D transfers | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
| Tensor::Tensor(const std::string& name, |
| Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); |
In fact, we can change the api of reference computing by taking directly const tensor& therefore we don't need to re-allocate the input and do one extra copy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think of 3ecea7f? This also merges the mxfp8/non-mxfp8 paths.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for consolidating with existing apis in test_common.cu.
In fact, I still see some cudaMalloc and cudaFree, which can be replaced by using existing test tensor class apis.
For example, the device pointer for scale (
| NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); |
TransformerEngine/tests/cpp/test_common.cu
Lines 321 to 335 in 2bc74c8
| if (rowwise) { | |
| (void)cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) | |
| (void)cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); | |
| rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size); | |
| std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); | |
| auto scale_dtype = rowwise_scale_meta.type; | |
| tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); | |
| } | |
| if (columnwise) { | |
| (void)cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) | |
| (void)cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); | |
| columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size); | |
| std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); | |
| auto scale_dtype = colwise_scale_meta.type; | |
| tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I replaced the remaining raw allocations in the reference path with test::Tensor for the temporary device buffers (RefD/RefGelu/RefAmax) in e11e400.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Yeah, it indeed saved some cudaMalloc/cudaFrees.
How about we put the RefD instantiation inside PerformTest, and pass the Tensor RefD (including its RefAmax D) and RefPreGeluOut to run_reference directly (instead of std::unique_ptr<D_Type[]>& ref_D, float* ref_amax_d, std::unique_ptr<Gelu_Type[]>& ref_pre_gelu_out). Then this can save some ref cpu ptr allocation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think of 325ece6?
alextmagro marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have both reference and target runs on GPU, we can just run one single device synchronization at the very end of both runs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 325ece6
Uh oh!
There was an error while loading. Please reload this page.