@@ -919,6 +919,210 @@ class TensorWrapper {
919919 NVTETensor tensor_ = nullptr ;
920920};
921921
922+ /* ! \struct GroupedTensorWrapper
923+ * \brief C++ wrapper for the NVTEGroupedTensor class.
924+ */
925+ class GroupedTensorWrapper {
926+ public:
927+ /* ! \brief Constructs new GroupedTensorWrapper.
928+ *
929+ * Create a new TE grouped tensor with a given logical shape.
930+ * TE grouped tensors are just wrappers on top of raw data and do not
931+ * own memory.
932+ *
933+ * \param[in] num_tensors Number of tensors in the group (must be > 0).
934+ * \param[in] logical_shape Logical 2D shape of the grouped data.
935+ * \param[in] scaling_mode Tensor data format.
936+ */
937+ GroupedTensorWrapper (const size_t num_tensors, const NVTEShape &logical_shape,
938+ const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
939+ : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {}
940+
941+ /* ! \brief Constructs new GroupedTensorWrapper.
942+ *
943+ * Create a new TE grouped tensor with a given logical shape.
944+ *
945+ * \param[in] num_tensors Number of tensors in the group (must be > 0).
946+ * \param[in] logical_shape Logical 2D shape of the grouped data.
947+ * \param[in] scaling_mode Tensor data format.
948+ */
949+ GroupedTensorWrapper (const size_t num_tensors, const std::vector<size_t > &logical_shape,
950+ const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
951+ : GroupedTensorWrapper(num_tensors,
952+ nvte_make_shape (logical_shape.data(), logical_shape.size()),
953+ scaling_mode) {}
954+
955+ /* ! \brief GroupedTensorWrapper destructor. */
956+ ~GroupedTensorWrapper () { nvte_destroy_grouped_tensor (tensor_); }
957+
958+ GroupedTensorWrapper &operator =(const GroupedTensorWrapper &other) = delete ;
959+ GroupedTensorWrapper (const GroupedTensorWrapper &other) = delete ;
960+
961+ /* ! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */
962+ GroupedTensorWrapper (GroupedTensorWrapper &&other) {
963+ tensor_ = other.tensor_ ;
964+ other.tensor_ = nullptr ;
965+ }
966+
967+ /* ! \brief Assign the data from existing GroupedTensorWrapper. */
968+ GroupedTensorWrapper &operator =(GroupedTensorWrapper &&other) {
969+ if (this == &other) return *this ;
970+ nvte_destroy_grouped_tensor (tensor_);
971+ tensor_ = other.tensor_ ;
972+ other.tensor_ = nullptr ;
973+ return *this ;
974+ }
975+
976+ // Parameter setters
977+ template <typename ShapeType>
978+ GroupedTensorWrapper &set_parameter (const NVTEGroupedTensorParam param, void *dptr, DType type,
979+ const ShapeType &shape) noexcept {
980+ NVTEShape nvte_shape = this ->convertShape (shape);
981+ NVTEBasicTensor data = {dptr, static_cast <NVTEDType>(type), nvte_shape};
982+ nvte_set_grouped_tensor_param (&tensor_, param, &data);
983+ return *this ;
984+ }
985+
986+ template <typename ShapeType>
987+ GroupedTensorWrapper &set_rowwise_data (void *dptr, DType type, const ShapeType &shape) noexcept {
988+ return set_parameter (kNVTEGroupedRowwiseData , dptr, type, shape);
989+ }
990+
991+ template <typename ShapeType>
992+ GroupedTensorWrapper &set_columnwise_data (void *dptr, DType type,
993+ const ShapeType &shape) noexcept {
994+ return set_parameter (kNVTEGroupedColumnwiseData , dptr, type, shape);
995+ }
996+
997+ template <typename ShapeType>
998+ GroupedTensorWrapper &set_scale (void *dptr, DType type, const ShapeType &shape) noexcept {
999+ return set_parameter (kNVTEGroupedScale , dptr, type, shape);
1000+ }
1001+
1002+ template <typename ShapeType>
1003+ GroupedTensorWrapper &set_amax (void *dptr, DType type, const ShapeType &shape) noexcept {
1004+ return set_parameter (kNVTEGroupedAmax , dptr, type, shape);
1005+ }
1006+
1007+ template <typename ShapeType>
1008+ GroupedTensorWrapper &set_rowwise_scale_inv (void *dptr, DType type,
1009+ const ShapeType &shape) noexcept {
1010+ return set_parameter (kNVTEGroupedRowwiseScaleInv , dptr, type, shape);
1011+ }
1012+
1013+ template <typename ShapeType>
1014+ GroupedTensorWrapper &set_columnwise_scale_inv (void *dptr, DType type,
1015+ const ShapeType &shape) noexcept {
1016+ return set_parameter (kNVTEGroupedColumnwiseScaleInv , dptr, type, shape);
1017+ }
1018+
1019+ template <typename ShapeType>
1020+ GroupedTensorWrapper &set_columnwise_amax (void *dptr, DType type,
1021+ const ShapeType &shape) noexcept {
1022+ return set_parameter (kNVTEGroupedColumnwiseAmax , dptr, type, shape);
1023+ }
1024+
1025+ template <typename ShapeType>
1026+ GroupedTensorWrapper &set_first_dims (void *dptr, DType type, const ShapeType &shape) noexcept {
1027+ return set_parameter (kNVTEGroupedFirstDims , dptr, type, shape);
1028+ }
1029+
1030+ template <typename ShapeType>
1031+ GroupedTensorWrapper &set_last_dims (void *dptr, DType type, const ShapeType &shape) noexcept {
1032+ return set_parameter (kNVTEGroupedLastDims , dptr, type, shape);
1033+ }
1034+
1035+ template <typename ShapeType>
1036+ GroupedTensorWrapper &set_tensor_offsets (void *dptr, DType type,
1037+ const ShapeType &shape) noexcept {
1038+ return set_parameter (kNVTEGroupedTensorOffsets , dptr, type, shape);
1039+ }
1040+
1041+ // Parameter getters
1042+ NVTEBasicTensor get_parameter (const NVTEGroupedTensorParam param) const noexcept {
1043+ return nvte_get_grouped_tensor_param (tensor_, param);
1044+ }
1045+
1046+ NVTEBasicTensor get_rowwise_data () const noexcept {
1047+ return get_parameter (kNVTEGroupedRowwiseData );
1048+ }
1049+
1050+ NVTEBasicTensor get_columnwise_data () const noexcept {
1051+ return get_parameter (kNVTEGroupedColumnwiseData );
1052+ }
1053+
1054+ NVTEBasicTensor get_scale () const noexcept { return get_parameter (kNVTEGroupedScale ); }
1055+
1056+ NVTEBasicTensor get_amax () const noexcept { return get_parameter (kNVTEGroupedAmax ); }
1057+
1058+ NVTEBasicTensor get_rowwise_scale_inv () const noexcept {
1059+ return get_parameter (kNVTEGroupedRowwiseScaleInv );
1060+ }
1061+
1062+ NVTEBasicTensor get_columnwise_scale_inv () const noexcept {
1063+ return get_parameter (kNVTEGroupedColumnwiseScaleInv );
1064+ }
1065+
1066+ NVTEBasicTensor get_columnwise_amax () const noexcept {
1067+ return get_parameter (kNVTEGroupedColumnwiseAmax );
1068+ }
1069+
1070+ NVTEBasicTensor get_first_dims () const noexcept { return get_parameter (kNVTEGroupedFirstDims ); }
1071+
1072+ NVTEBasicTensor get_last_dims () const noexcept { return get_parameter (kNVTEGroupedLastDims ); }
1073+
1074+ NVTEBasicTensor get_tensor_offsets () const noexcept {
1075+ return get_parameter (kNVTEGroupedTensorOffsets );
1076+ }
1077+
1078+ /* ! \brief Get an underlying NVTEGroupedTensor.
1079+ *
1080+ * \return NVTEGroupedTensor held by this GroupedTensorWrapper.
1081+ */
1082+ NVTEGroupedTensor data () const noexcept { return tensor_; }
1083+
1084+ /* ! \brief Get the number of tensors in this GroupedTensorWrapper. */
1085+ size_t num_tensors () const noexcept {
1086+ if (tensor_ == nullptr ) return 0 ;
1087+ return nvte_grouped_tensor_num_tensors (tensor_);
1088+ }
1089+
1090+ /* ! \brief Get the data type of this GroupedTensorWrapper. */
1091+ DType dtype () const noexcept {
1092+ if (tensor_ == nullptr ) return DType::kNumTypes ;
1093+ return static_cast <DType>(nvte_grouped_tensor_type (tensor_));
1094+ }
1095+
1096+ /* ! \brief Get a scaling mode of the grouped tensor. */
1097+ NVTEScalingMode scaling_mode () const noexcept {
1098+ if (tensor_ == nullptr ) return NVTE_DELAYED_TENSOR_SCALING;
1099+ return nvte_grouped_tensor_scaling_mode (tensor_);
1100+ }
1101+
1102+ /* ! \brief Get the logical shape of this GroupedTensorWrapper. */
1103+ const NVTEShape logical_shape () const noexcept {
1104+ if (tensor_ == nullptr ) {
1105+ return emptyShape;
1106+ }
1107+ return nvte_get_grouped_tensor_logical_shape (tensor_);
1108+ }
1109+
1110+ static constexpr size_t defaultData = 1 ;
1111+ static constexpr NVTEShape defaultShape = {
1112+ {defaultData, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 }, 1 };
1113+ static constexpr NVTEShape emptyShape = {{0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 }, 1 };
1114+
1115+ private:
1116+ NVTEShape convertShape (const NVTEShape &s) { return s; }
1117+
1118+ NVTEShape convertShape (const std::vector<size_t > &s) {
1119+ return nvte_make_shape (s.data (), s.size ());
1120+ }
1121+
1122+ /* ! \brief Wrapped NVTEGroupedTensor. */
1123+ NVTEGroupedTensor tensor_ = nullptr ;
1124+ };
1125+
9221126/* ! \enum Float8BlockScaleTensorFormat
9231127 * \brief Data format for an FP8 block-scaled tensor
9241128 */
0 commit comments