@@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      "Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      "Tensor tmp_out, Tensor query, Tensor key_cache,"
+      "Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      "Tensor! tmp_out, Tensor query, Tensor key_cache,"
       "Tensor value_cache, int num_kv_heads, float scale,"
       "Tensor block_tables, Tensor seq_lens, int block_size,"
       "int max_seq_len, Tensor? alibi_slopes,"
@@ -73,7 +73,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

   // prepare_inputs advance_step
-  ops.def("advance_step", &advance_step);
+  ops.def(
+      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "Tensor! input_tokens, Tensor sampled_token_ids, "
+      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
+      "Tensor block_tables) -> ()");
   ops.impl("advance_step", torch::kCUDA, &advance_step);

   // Layernorm
@@ -110,27 +114,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AQLM.
-  ops.def("aqlm_gemm", &aqlm_gemm);
+  ops.def(
+      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
+      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
+      "-> Tensor");
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

   // Decompression method for AQLM.
-  ops.def("aqlm_dequant", &aqlm_dequant);
+  ops.def(
+      "aqlm_dequant(Tensor codes, Tensor codebooks, "
+      "int[] codebook_partition_sizes) -> Tensor");
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

   // Quantized GEMM for AWQ.
-  ops.def("awq_gemm", &awq_gemm);
+  ops.def(
+      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

   // Dequantization for AWQ.
-  ops.def("awq_dequantize", &awq_dequantize);
+  ops.def(
+      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

+  // Note about marlin kernel 'workspace' arguments:
+  // Technically these should be mutable since they are modified by the kernel.
+  // But since they are set back to zero once the kernel is finished we can
+  // hand wave and say that they have no net effect.
+  //
+  // The reason to mark 'workspace' as immutable is so that they don't interfere
+  // with using ScalarType arguments in the ops. If they are marked as mutable,
+  // pytorch throws an assert in
+  // 'torch._higher_order_ops._register_effectful_op' that prevents these
+  // kernels from being torch.compile'd.
+  // See the following document for more info on custom types and ops that use
+  // custom types:
+  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
+
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
-  ops.def("marlin_gemm", &marlin_gemm);
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);

   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -149,35 +182,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);

   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce) -> Tensor");
   ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

   // gptq_marlin repack from GPTQ.
-  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+  ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);

   // awq_marlin repack from AWQ.
-  ops.def("awq_marlin_repack", &awq_marlin_repack);
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
   ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
+  ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);

   // Dequantization for GGML.
-  ops.def("ggml_dequantize", &ggml_dequantize);
+  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

   // mmvq kernel for GGML.
-  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
+  ops.def(
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
+  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
-  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.def(
+      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int num_bits, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);

   // marlin_qqq_gemm for QQQ.
-  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);

   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -199,16 +252,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Check if cutlass scaled_mm is supported for CUDA devices of the given
   // capability
-  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
-  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
-           &cutlass_scaled_mm_supports_fp8);
+  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
       "Tensor! A, Tensor! B, Tensor! C,"
       "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor? x) -> Tensor[]");
+      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

   ops.def(
@@ -230,7 +283,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #endif

   // Quantized GEMM for GPTQ.
-  ops.def("gptq_gemm", &gptq_gemm);
+  // Note: even though the C++ inferred schema is correct for this op, it seems
+  // to prevent the meta function registry.
+  ops.def(
+      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
+      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
+      "-> Tensor");
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

   // Post processing for GPTQ.
@@ -250,8 +308,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale, Tensor? scale_ub) -> "
+      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
@@ -288,8 +346,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

   // Reshape the key and value tensors and cache them.
@@ -314,33 +372,37 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
-      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
-      "kv_cache_dtype) -> ()");
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 }

 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
   // Cuda utils

   // Gets the specified device attribute.
-  cuda_utils.def("get_device_attribute", &get_device_attribute);
-  cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+  cuda_utils.impl("get_device_attribute", &get_device_attribute);

   // Gets the maximum shared memory per block device attribute.
-  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
-                 &get_max_shared_memory_per_block_device_attribute);
+  cuda_utils.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
   cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
-                  torch::kCUDA,
                   &get_max_shared_memory_per_block_device_attribute);
 }

 #ifndef USE_ROCM
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
-  custom_ar.def("init_custom_ar", &init_custom_ar);
+  custom_ar.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

-  custom_ar.def("should_custom_ar", &should_custom_ar);
+  custom_ar.def(
+      "should_custom_ar(Tensor inp, int max_size, int world_size, "
+      "bool full_nvlink) -> bool");
   custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -352,21 +414,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);

   custom_ar.def("dispose", &dispose);
-  custom_ar.impl("dispose", torch::kCPU, &dispose);
-
   custom_ar.def("meta_size", &meta_size);
-  custom_ar.impl("meta_size", torch::kCPU, &meta_size);

-  custom_ar.def("register_buffer", &register_buffer);
+  custom_ar.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
   custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);

   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
-                 &get_graph_buffer_ipc_meta);
-
   custom_ar.def("register_graph_buffers", &register_graph_buffers);
-  custom_ar.impl("register_graph_buffers", torch::kCPU,
-                 &register_graph_buffers);
 }
 #endif

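Most of this diff replaces registrations of the form ops.def("name", &fn), where PyTorch infers the schema from the C++ signature, with an explicit schema string plus a separate ops.impl(...) call, and it adds torch::kMeta implementations for the repack ops. The explicit schema spells out argument types and mutability (the "Tensor!" annotations), and the meta kernels let the ops be shape-traced by torch.compile without launching any CUDA work. Below is a minimal sketch of that same registration pattern in isolation; the library name my_ext, the scale op, and its meta function are hypothetical stand-ins for illustration only, not part of this commit.

#include <torch/library.h>
#include <torch/torch.h>

// Stand-in for a real CUDA kernel wrapper (hypothetical).
torch::Tensor scale_cuda(const torch::Tensor& input, double factor) {
  return input * factor;
}

// Meta implementation: produces only output metadata (shape/dtype), never
// touches data, which is what torch.compile needs for symbolic tracing.
torch::Tensor scale_meta(const torch::Tensor& input, double /*factor*/) {
  return torch::empty_like(input);
}

TORCH_LIBRARY(my_ext, ops) {
  // Explicit schema string instead of ops.def("scale", &scale_cuda); an
  // argument the kernel mutates would be written as "Tensor!" here.
  ops.def("scale(Tensor input, float factor) -> Tensor");
  ops.impl("scale", torch::kCUDA, &scale_cuda);
  ops.impl("scale", torch::kMeta, &scale_meta);
}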