Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ compile_commands.json
.nfs
tensor_dumps/
artifacts/
*.DS_Store
5 changes: 3 additions & 2 deletions transformer_engine/common/include/transformer_engine/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workplace, cudaStream_t stream);
NVTETensor workspace, cudaStream_t stream);

/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
Expand Down Expand Up @@ -263,7 +263,8 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str
*
* \param[in] inputs List of input tensors to be cast.
* \param[in,out] outputs List of output quantized tensors.
* \param[in] quant_config (Optional) Quantization configurations.
* \param[in] quant_config (Optional) Quantization configurations.
* \param[in] num_tensors Number of input and output tensors.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/

/*! \file transpose_with_noop.h
* \brief Functions handling transposes with no-op.
/*! \file cast_transpose_noop.h
* \brief Transpose functions with no-op flag.
*/

#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
Expand Down Expand Up @@ -673,7 +672,7 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
* \param[in] len batch_size x sequence_length.
* \param[in] stream CUDA stream used for this operation.
*/
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlens, NVTETensor workspace, size_t len,
cudaStream_t stream);

/*! \brief Set the seed and offset for RNG state.
Expand Down Expand Up @@ -830,8 +829,7 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] b Batch size.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] t Packed sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,11 @@ class MatmulConfigWrapper {
MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;

/*! \brief Move constructor. */
MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
/*! \brief Move-assignment operator. */
MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **
* and populate the amax of the corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] amaxes Array of output tensors. Only the amax is updated.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] stream CUDA stream used for the operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,16 @@ void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETen
NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);

/*! \brief Helper to enable cuDNN backend for normalization
/*! \brief Set whether to enable cuDNN backend for normalization forward.
*
* \param[in] bool Enable if True
* \param[in] enable Whether to enable cuDNN backend.
*/
void nvte_enable_cudnn_norm_fwd(bool enable);

/*! \brief Set whether to enable cuDNN backend for normalization backward.
*
* \param[in] enable Whether to enable cuDNN backend.
*/
void nvte_enable_cudnn_norm_bwd(bool enable);

/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
Expand All @@ -176,11 +181,14 @@ void nvte_enable_cudnn_norm_bwd(bool enable);
* Currently this only applies to the CuDNN backend. If CuDNN is not used,
* this setting has no effect.
*
* \param[in] bool Enable if True
* \param[in] enable Whether to enable zero-centered gamma.
*/
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);

#ifdef __cplusplus
/*! \brief Normalization function type */
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#endif

#ifdef __cplusplus
} // extern "C"
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/include/transformer_engine/recipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ extern "C" {
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\text requires amsmath.

* @f$ 2^{-margin} \max_{fp8\_dtype} @f$.
*
* \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales]
Expand Down Expand Up @@ -54,7 +54,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
* @f$ 2^{-margin} \max_{fp8\_dtype} @f$.
*
* \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction.
* Shape: [num_scales * num_tensors]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
*
* \param[in] inputs Input tensors with non-swizzled scale_inv.
* \param[in,out] outputs Output tensors which hosts swizzled scale_inv.
* \param[in] num_tensors Number of input and output tensors.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ struct NVTEShape {
* It does not own the memory it points to.
*/
struct NVTEBasicTensor {
/*! Pointer to data buffer. */
void *data_ptr;
/*! Data type. */
NVTEDType dtype;
/*! Tensor shape. */
NVTEShape shape;
};

Expand Down Expand Up @@ -144,7 +147,7 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor);
*
* \param[data] Pointer to start of shape array. If NULL, the shape
* will be filled with zeros.
* \param[data] Number of dimensions (must be <= 14)
* \param[ndim] Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
*/
Expand Down Expand Up @@ -177,7 +180,7 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
/*! \brief Get the size of a specific tensor dimension.
*
* \param[in] tensor Tensor.
* \param[in] size_t Dimension index.
* \param[in] dim Dimension index.
*
* \return Size of the tensor at the specified dimension.
*/
Expand Down Expand Up @@ -258,8 +261,7 @@ NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor);
/*! \brief Reset tensor value to zero.
*
* \param[in] tensor Tensor.
*
* \return A scale_inv shape of the input tensor.
* \param[in] stream CUDA stream to use for the operation.
*/
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream);

Expand Down Expand Up @@ -539,7 +541,7 @@ enum class DType {
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_fp8_dtype(const DType t) {
return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
Expand All @@ -548,14 +550,14 @@ inline bool is_fp8_dtype(const DType t) {
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }

/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
Expand All @@ -579,6 +581,7 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
Expand Down Expand Up @@ -615,6 +618,7 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
Expand Down Expand Up @@ -766,7 +770,7 @@ class TensorWrapper {

/*! \brief Get the size of this TensorWrapper in the given dimension.
*
* \param[in] size_t Dimension index.
* \param[in] dim Dimension index.
*
* \return Size of this TensorWrapper in given dimension.
*/
Expand Down Expand Up @@ -935,9 +939,11 @@ class QuantizationConfigWrapper {
QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete;
QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete;

/*! \brief Move constructor. */
QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
/*! \brief Move-assignment operator. */
QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_quantization_config(config_);
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/common/include/transformer_engine/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor a
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
Expand All @@ -250,7 +250,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
Expand All @@ -269,7 +269,7 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
Expand All @@ -288,7 +288,7 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
Expand All @@ -307,7 +307,7 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
Expand Down
Loading