@@ -84,6 +84,53 @@ void init_triton_hopper_passes(py::module &&m) {
       mlir::createNVGPUWarpSpecialization, int, bool);
 }
 
+static void checkMatmulConstraints(const std::string &A_dtype,
+                                   const std::string &B_dtype,
+                                   const std::string &C_dtype,
+                                   const std::vector<int> &A_shape,
+                                   const std::vector<int> &B_shape,
+                                   const std::vector<int> &C_shape) {
+  if (A_dtype != B_dtype || A_dtype != C_dtype) {
+    throw std::runtime_error("Data types do not match.");
+  }
+  if (A_dtype != "torch.float8_e4m3fn" && A_dtype != "torch.float16") {
+    throw std::runtime_error("Unsupported data type.");
+  }
+
+  if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) {
+    throw std::runtime_error("Only 2D matrices are supported.");
+  }
+
+  int k = A_shape[1];
+  if (k != B_shape[1]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. A is [" + std::to_string(A_shape[0]) +
+        ", " + std::to_string(A_shape[1]) + "], B is [" +
+        std::to_string(B_shape[0]) + ", " + std::to_string(B_shape[1]) +
+        "]. Expected A.shape[1] == B.shape[1]. Note "
+        "that B needs to be transposed.");
+  }
+
+  int m = A_shape[0];
+  if (m != C_shape[0]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. A is [" + std::to_string(A_shape[0]) +
+        ", " + std::to_string(A_shape[1]) + "], C is [" +
+        std::to_string(C_shape[0]) + ", " + std::to_string(C_shape[1]) +
+        "]. Expected A.shape[0] == C.shape[0].");
+  }
+
+  int n = B_shape[0];
+  if (n != C_shape[1]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. B is [" + std::to_string(B_shape[0]) +
+        ", " + std::to_string(B_shape[1]) + "], C is [" +
+        std::to_string(C_shape[0]) + ", " + std::to_string(C_shape[1]) +
+        "]. Expected B.shape[0] == C.shape[1]. Note "
+        "that B needs to be transposed.");
+  }
+}
+
 void init_triton_nvidia(py::module &&m) {
   auto passes = m.def_submodule("passes");
   init_triton_nvidia_passes_nvws(passes.def_submodule("nvws"));
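
For reference, the convention the new `checkMatmulConstraints` helper enforces is that B is passed already transposed: A is [M, K], B is [N, K], and C is [M, N], with all three dtypes equal and either float16 or float8_e4m3fn. A minimal sketch of tensors that pass the checks, assuming a CUDA build of PyTorch (the names M, N, K are illustrative):

```python
import torch

# Shapes satisfying the new checks: A [M, K], B [N, K] (i.e. B^T), C [M, N].
M, N, K = 128, 256, 64
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")  # row-major B^T
C = torch.empty(M, N, dtype=torch.float16, device="cuda")
# All three dtypes must match and be torch.float16 or torch.float8_e4m3fn.
```
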
@@ -155,22 +202,64 @@ void init_triton_nvidia(py::module &&m) {
             workspace.attr("element_size")().cast<size_t>();
         return new CublasLtInstance(wrk_ptr, wrk_size);
       }))
-      .def("matmul", [](CublasLtInstance &self, py::object &A, py::object &B,
-                        py::object &C) {
+      .def("matmul",
+           [](CublasLtInstance &self, py::object &A, py::object &B,
+              py::object &C) {
+             auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
+             auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();
+             auto C_ptr = C.attr("data_ptr")().cast<uint64_t>();
+
+             auto A_shape = A.attr("shape").cast<std::vector<int>>();
+             auto B_shape = B.attr("shape").cast<std::vector<int>>();
+             auto C_shape = C.attr("shape").cast<std::vector<int>>();
+
+             auto A_dtype =
+                 A.attr("dtype").attr("__str__")().cast<std::string>();
+             auto B_dtype =
+                 B.attr("dtype").attr("__str__")().cast<std::string>();
+             auto C_dtype =
+                 C.attr("dtype").attr("__str__")().cast<std::string>();
+
+             checkMatmulConstraints(A_dtype, B_dtype, C_dtype, A_shape,
+                                    B_shape, C_shape);
+
+             std::string dtype_str =
+                 A_dtype.substr(A_dtype.find_last_of('.') + 1);
+             cudaDataType_t dtype;
+             if (dtype_str == "float8_e4m3fn") {
+               dtype = CUDA_R_8F_E4M3;
+             } else if (dtype_str == "float16") {
+               dtype = CUDA_R_16F;
+             }
+
+             self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr,
+                         C_ptr, dtype);
+           })
+      .def("gemm", [](CublasLtInstance &self, py::object &A, py::object &B,
+                      py::object &C, py::object &D, float alpha, float beta) {
        auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
        auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();
        auto C_ptr = C.attr("data_ptr")().cast<uint64_t>();
+       auto D_ptr = D.attr("data_ptr")().cast<uint64_t>();
 
        auto A_shape = A.attr("shape").cast<std::vector<int>>();
        auto B_shape = B.attr("shape").cast<std::vector<int>>();
        auto C_shape = C.attr("shape").cast<std::vector<int>>();
+       auto D_shape = D.attr("shape").cast<std::vector<int>>();
 
        auto A_dtype = A.attr("dtype").attr("__str__")().cast<std::string>();
        auto B_dtype = B.attr("dtype").attr("__str__")().cast<std::string>();
        auto C_dtype = C.attr("dtype").attr("__str__")().cast<std::string>();
+       auto D_dtype = D.attr("dtype").attr("__str__")().cast<std::string>();
 
-       assert(A_dtype == B_dtype && A_dtype == C_dtype);
-       assert(A_dtype == "torch.float8_e4m3fn" || A_dtype == "torch.float16");
+       checkMatmulConstraints(A_dtype, B_dtype, D_dtype, A_shape, B_shape,
+                              D_shape);
+       if (C_dtype != "torch.float16") {
+         throw std::runtime_error("C dtype must be float16, got " + C_dtype);
+       }
+       if (C_shape != D_shape) {
+         throw std::runtime_error("C and D shapes must match");
+       }
 
        std::string dtype_str = A_dtype.substr(A_dtype.find_last_of('.') + 1);
        cudaDataType_t dtype;
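
The new `gemm` entry point validates D (the output) against A and B rather than C, pins C to float16, and forwards alpha and beta. Assuming `CublasLtInstance::gemm` follows the usual cuBLASLt epilogue D = alpha * (A @ B^T) + beta * C (the implementation is not shown in this diff), a hedged float16 reference in Python would be:

```python
def gemm_reference(A, B, C, alpha, beta):
    # Assumed semantics of the binding: D = alpha * (A @ B^T) + beta * C,
    # with B stored transposed as [N, K]; float16 case only.
    return alpha * (A @ B.T) + beta * C
```
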
@@ -180,43 +269,7 @@ void init_triton_nvidia(py::module &&m) {
          dtype = CUDA_R_16F;
        }
 
-       if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) {
-         throw std::runtime_error("Only 2D matrices are supported.");
-       }
-
-       int k = A_shape[1];
-       if (k != B_shape[1]) {
-         throw std::runtime_error("Matrix dimensions do not match. A is [" +
-                                  std::to_string(A_shape[0]) + ", " +
-                                  std::to_string(A_shape[1]) + "], B is [" +
-                                  std::to_string(B_shape[0]) + ", " +
-                                  std::to_string(B_shape[1]) +
-                                  "]. Expected A.shape[1] == B.shape[1]. Note "
-                                  "that B needs to be transposed.");
-       }
-
-       int m = A_shape[0];
-       if (m != C_shape[0]) {
-         throw std::runtime_error("Matrix dimensions do not match. A is [" +
-                                  std::to_string(A_shape[0]) + ", " +
-                                  std::to_string(A_shape[1]) + "], C is [" +
-                                  std::to_string(C_shape[0]) + ", " +
-                                  std::to_string(C_shape[1]) +
-                                  "]. Expected A.shape[0] == C.shape[0].");
-       }
-
-       int n = B_shape[0];
-       if (n != C_shape[1]) {
-         throw std::runtime_error("Matrix dimensions do not match. B is [" +
-                                  std::to_string(B_shape[0]) + ", " +
-                                  std::to_string(B_shape[1]) + "], C is [" +
-                                  std::to_string(C_shape[0]) + ", " +
-                                  std::to_string(C_shape[1]) +
-                                  "]. Expected B.shape[0] == C.shape[1]. Note "
-                                  "that B needs to be transposed.");
-       }
-
-       self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr,
-                   dtype);
+       self.gemm(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr,
+                 D_ptr, dtype, alpha, beta);
      });
}
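
A hypothetical end-to-end call sequence from Python; the import path, submodule and class names, and workspace sizing are assumptions, not shown in this diff. Arguments are passed positionally because the bindings declare no keyword names:

```python
import torch
from triton._C.libtriton import nvidia  # assumed import path

M, N, K = 128, 256, 64
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")  # B passed as B^T
C = torch.randn(M, N, dtype=torch.float16, device="cuda")
D = torch.empty(M, N, dtype=torch.float16, device="cuda")

# Workspace size is an illustrative guess; the constructor reads the tensor's
# data pointer and element size to compute the usable byte count.
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
cublas = nvidia.cublas.CublasLtInstance(workspace)  # names assumed

cublas.matmul(A, B, C)             # C = A @ B^T
cublas.gemm(A, B, C, D, 1.0, 1.0)  # D = alpha * A @ B^T + beta * C
```
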