Skip to content

Commit 759e7bb

Browse files
committed
C++ utils, untested
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent cf61339 commit 759e7bb

File tree

2 files changed

+302
-0
lines changed

2 files changed

+302
-0
lines changed

transformer_engine/common/include/transformer_engine/transformer_engine.h

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,210 @@ class TensorWrapper {
919919
NVTETensor tensor_ = nullptr;
920920
};
921921

922+
/*! \struct GroupedTensorWrapper
923+
* \brief C++ wrapper for the NVTEGroupedTensor class.
924+
*/
925+
class GroupedTensorWrapper {
926+
public:
927+
/*! \brief Constructs new GroupedTensorWrapper.
928+
*
929+
* Create a new TE grouped tensor with a given logical shape.
930+
* TE grouped tensors are just wrappers on top of raw data and do not
931+
* own memory.
932+
*
933+
* \param[in] num_tensors Number of tensors in the group (must be > 0).
934+
* \param[in] logical_shape Logical 2D shape of the grouped data.
935+
* \param[in] scaling_mode Tensor data format.
936+
*/
937+
GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape,
938+
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
939+
: tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {}
940+
941+
/*! \brief Constructs new GroupedTensorWrapper.
942+
*
943+
* Create a new TE grouped tensor with a given logical shape.
944+
*
945+
* \param[in] num_tensors Number of tensors in the group (must be > 0).
946+
* \param[in] logical_shape Logical 2D shape of the grouped data.
947+
* \param[in] scaling_mode Tensor data format.
948+
*/
949+
GroupedTensorWrapper(const size_t num_tensors, const std::vector<size_t> &logical_shape,
950+
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
951+
: GroupedTensorWrapper(num_tensors,
952+
nvte_make_shape(logical_shape.data(), logical_shape.size()),
953+
scaling_mode) {}
954+
955+
/*! \brief GroupedTensorWrapper destructor. */
956+
~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); }
957+
958+
GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete;
959+
GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete;
960+
961+
/*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */
962+
GroupedTensorWrapper(GroupedTensorWrapper &&other) {
963+
tensor_ = other.tensor_;
964+
other.tensor_ = nullptr;
965+
}
966+
967+
/*! \brief Assign the data from existing GroupedTensorWrapper. */
968+
GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) {
969+
if (this == &other) return *this;
970+
nvte_destroy_grouped_tensor(tensor_);
971+
tensor_ = other.tensor_;
972+
other.tensor_ = nullptr;
973+
return *this;
974+
}
975+
976+
// Parameter setters
977+
template <typename ShapeType>
978+
GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type,
979+
const ShapeType &shape) noexcept {
980+
NVTEShape nvte_shape = this->convertShape(shape);
981+
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(type), nvte_shape};
982+
nvte_set_grouped_tensor_param(&tensor_, param, &data);
983+
return *this;
984+
}
985+
986+
template <typename ShapeType>
987+
GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept {
988+
return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape);
989+
}
990+
991+
template <typename ShapeType>
992+
GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type,
993+
const ShapeType &shape) noexcept {
994+
return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape);
995+
}
996+
997+
template <typename ShapeType>
998+
GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept {
999+
return set_parameter(kNVTEGroupedScale, dptr, type, shape);
1000+
}
1001+
1002+
template <typename ShapeType>
1003+
GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
1004+
return set_parameter(kNVTEGroupedAmax, dptr, type, shape);
1005+
}
1006+
1007+
template <typename ShapeType>
1008+
GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type,
1009+
const ShapeType &shape) noexcept {
1010+
return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape);
1011+
}
1012+
1013+
template <typename ShapeType>
1014+
GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type,
1015+
const ShapeType &shape) noexcept {
1016+
return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape);
1017+
}
1018+
1019+
template <typename ShapeType>
1020+
GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type,
1021+
const ShapeType &shape) noexcept {
1022+
return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape);
1023+
}
1024+
1025+
template <typename ShapeType>
1026+
GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept {
1027+
return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape);
1028+
}
1029+
1030+
template <typename ShapeType>
1031+
GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept {
1032+
return set_parameter(kNVTEGroupedLastDims, dptr, type, shape);
1033+
}
1034+
1035+
template <typename ShapeType>
1036+
GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type,
1037+
const ShapeType &shape) noexcept {
1038+
return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape);
1039+
}
1040+
1041+
// Parameter getters
1042+
NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept {
1043+
return nvte_get_grouped_tensor_param(tensor_, param);
1044+
}
1045+
1046+
NVTEBasicTensor get_rowwise_data() const noexcept {
1047+
return get_parameter(kNVTEGroupedRowwiseData);
1048+
}
1049+
1050+
NVTEBasicTensor get_columnwise_data() const noexcept {
1051+
return get_parameter(kNVTEGroupedColumnwiseData);
1052+
}
1053+
1054+
NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); }
1055+
1056+
NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); }
1057+
1058+
NVTEBasicTensor get_rowwise_scale_inv() const noexcept {
1059+
return get_parameter(kNVTEGroupedRowwiseScaleInv);
1060+
}
1061+
1062+
NVTEBasicTensor get_columnwise_scale_inv() const noexcept {
1063+
return get_parameter(kNVTEGroupedColumnwiseScaleInv);
1064+
}
1065+
1066+
NVTEBasicTensor get_columnwise_amax() const noexcept {
1067+
return get_parameter(kNVTEGroupedColumnwiseAmax);
1068+
}
1069+
1070+
NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); }
1071+
1072+
NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); }
1073+
1074+
NVTEBasicTensor get_tensor_offsets() const noexcept {
1075+
return get_parameter(kNVTEGroupedTensorOffsets);
1076+
}
1077+
1078+
/*! \brief Get an underlying NVTEGroupedTensor.
1079+
*
1080+
* \return NVTEGroupedTensor held by this GroupedTensorWrapper.
1081+
*/
1082+
NVTEGroupedTensor data() const noexcept { return tensor_; }
1083+
1084+
/*! \brief Get the number of tensors in this GroupedTensorWrapper. */
1085+
size_t num_tensors() const noexcept {
1086+
if (tensor_ == nullptr) return 0;
1087+
return nvte_grouped_tensor_num_tensors(tensor_);
1088+
}
1089+
1090+
/*! \brief Get the data type of this GroupedTensorWrapper. */
1091+
DType dtype() const noexcept {
1092+
if (tensor_ == nullptr) return DType::kNumTypes;
1093+
return static_cast<DType>(nvte_grouped_tensor_type(tensor_));
1094+
}
1095+
1096+
/*! \brief Get a scaling mode of the grouped tensor. */
1097+
NVTEScalingMode scaling_mode() const noexcept {
1098+
if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING;
1099+
return nvte_grouped_tensor_scaling_mode(tensor_);
1100+
}
1101+
1102+
/*! \brief Get the logical shape of this GroupedTensorWrapper. */
1103+
const NVTEShape logical_shape() const noexcept {
1104+
if (tensor_ == nullptr) {
1105+
return emptyShape;
1106+
}
1107+
return nvte_get_grouped_tensor_logical_shape(tensor_);
1108+
}
1109+
1110+
static constexpr size_t defaultData = 1;
1111+
static constexpr NVTEShape defaultShape = {
1112+
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
1113+
static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
1114+
1115+
private:
1116+
NVTEShape convertShape(const NVTEShape &s) { return s; }
1117+
1118+
NVTEShape convertShape(const std::vector<size_t> &s) {
1119+
return nvte_make_shape(s.data(), s.size());
1120+
}
1121+
1122+
/*! \brief Wrapped NVTEGroupedTensor. */
1123+
NVTEGroupedTensor tensor_ = nullptr;
1124+
};
1125+
9221126
/*! \enum Float8BlockScaleTensorFormat
9231127
* \brief Data format for an FP8 block-scaled tensor
9241128
*/

transformer_engine/pytorch/csrc/type_converters.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,104 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
156156
return ret;
157157
}
158158

159+
NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) {
160+
auto *quantizer_ptr = quantizer.ptr();
161+
if (IsMXFP8Quantizers(quantizer_ptr)) {
162+
return NVTE_MXFP8_1D_SCALING;
163+
}
164+
if (IsNVFP4Quantizers(quantizer_ptr)) {
165+
return NVTE_NVFP4_1D_SCALING;
166+
}
167+
if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) {
168+
const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
169+
return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
170+
}
171+
return NVTE_DELAYED_TENSOR_SCALING;
172+
}
173+
174+
/*! \brief Build a GroupedTensorWrapper from a PyTorch GroupedTensor object.
 *
 * The returned wrapper aliases the storage of the tensors held by the Python
 * object; it does not own memory. Attributes that are None are skipped.
 *
 * \param[in] tensor Python GroupedTensor with attributes num_tensors,
 *                   logical_shape, quantizers, and the optional data /
 *                   scale / amax / shape-metadata tensors.
 *
 * \return GroupedTensorWrapper referencing the Python object's buffers.
 */
GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) {
  const auto num_tensors = tensor.attr("num_tensors").cast<size_t>();
  const auto logical_shape = tensor.attr("logical_shape").cast<std::vector<size_t>>();

  // Infer the scaling mode from the first quantizer, when one is provided.
  NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  const py::object quantizers_attr = tensor.attr("quantizers");
  if (!quantizers_attr.is_none()) {
    const auto quantizers = quantizers_attr.cast<py::list>();
    if (!quantizers.empty() && !quantizers[0].is_none()) {
      scaling_mode = ScalingModeFromQuantizer(quantizers[0]);
    }
  }

  auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode);

  // Fetch an optional at::Tensor attribute once and, if it is not None,
  // forward its pointer/dtype/shape to the given wrapper setter.
  const auto set_if_present = [&](const char *attr_name, auto &&setter) {
    const py::object attr = tensor.attr(attr_name);
    if (attr.is_none()) return;
    const auto &t = attr.cast<at::Tensor>();
    setter(t.data_ptr(), GetTransformerEngineDType(t.scalar_type()), getTensorShape(t));
  };

  // Row-wise and column-wise data
  set_if_present("data",
                 [&](void *p, auto d, const auto &s) { ret.set_rowwise_data(p, d, s); });
  set_if_present("columnwise_data",
                 [&](void *p, auto d, const auto &s) { ret.set_columnwise_data(p, d, s); });

  // Scale and amax
  set_if_present("scale", [&](void *p, auto d, const auto &s) { ret.set_scale(p, d, s); });
  set_if_present("amax", [&](void *p, auto d, const auto &s) { ret.set_amax(p, d, s); });
  set_if_present("columnwise_amax",
                 [&](void *p, auto d, const auto &s) { ret.set_columnwise_amax(p, d, s); });

  // Scale inverses
  set_if_present("scale_inv",
                 [&](void *p, auto d, const auto &s) { ret.set_rowwise_scale_inv(p, d, s); });
  set_if_present("columnwise_scale_inv",
                 [&](void *p, auto d, const auto &s) { ret.set_columnwise_scale_inv(p, d, s); });

  // Shape metadata
  set_if_present("first_dims",
                 [&](void *p, auto d, const auto &s) { ret.set_first_dims(p, d, s); });
  set_if_present("last_dims",
                 [&](void *p, auto d, const auto &s) { ret.set_last_dims(p, d, s); });
  set_if_present("tensor_offsets",
                 [&](void *p, auto d, const auto &s) { ret.set_tensor_offsets(p, d, s); });

  return ret;
}
256+
159257
} // namespace detail
160258

161259
} // namespace transformer_engine::pytorch

0 commit comments

Comments (0)