-
Notifications
You must be signed in to change notification settings - Fork 112
Expand file tree
/
Copy pathmatmul_cuda.h
More file actions
1252 lines (1114 loc) · 50 KB
/
matmul_cuda.h
File metadata and controls
1252 lines (1114 loc) · 50 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <cublasLt.h>
#ifdef MATX_ENABLE_CUTLASS
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_batched.h"
#endif
#include <cstdio>
#include <numeric>
#include "cublas_v2.h"
#include "matx/core/cache.h"
#include "matx/core/error.h"
#include "matx/core/nvtx.h"
#include "matx/core/tensor.h"
#include "matx/transforms/matmul/matmul_common.h"
namespace matx {
/**
 * Selects the backend library used to execute a GEMM.
 *
 * The provider is tied directly to the underlying library, and a given
 * provider may offer capabilities that the others do not.
 */
typedef enum {
  PROVIDER_TYPE_CUTLASS = 0,   ///< Use the CUTLASS library
  PROVIDER_TYPE_CUBLASLT = 1,  ///< Use the cuBLASLt library
  PROVIDER_TYPE_AUTO = 2,      ///< Select a provider automatically
  PROVIDER_TYPE_SENTINEL = 3   ///< One past the last valid provider. Do not use
} MatMulCUDAProvider_t;
namespace detail {
// Highest tensor rank that is issued as a single batched GEMM call. The batch
// count is built from the dimensions between (rank - threshold) and the last
// two; tensors of higher rank are processed by looping over the remaining
// outer dimensions.
static constexpr int MATMUL_BATCH_RANK_THRESHOLD = 4;
/**
 * Memory ordering tag for a matrix operand.
 */
typedef enum {
  MEM_ORDER_ROW_MAJOR = 0,  ///< Row-major ordering
  MEM_ORDER_COL_MAJOR = 1   ///< Column-major ordering
} MemOrder_t;
/**
 * Compile-time check of whether the value types of A, B and C form a
 * combination accepted for a GEMM with the given provider.
 *
 * Returns false when all three value types differ from one another. For the
 * cuBLASLt provider the accepted types are further restricted when A and B
 * share a value type; every other case is reported as compatible.
 *
 * @tparam OpA   Operator type of A (must expose value_type)
 * @tparam OpB   Operator type of B (must expose value_type)
 * @tparam OpC   Operator type of C (must expose value_type)
 * @tparam PROV  GEMM provider
 */
template <typename OpA, typename OpB, typename OpC, MatMulCUDAProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
constexpr bool CompatibleGemmCUDATypes() {
  using TypeA = typename OpA::value_type;
  using TypeB = typename OpB::value_type;
  using TypeC = typename OpC::value_type;

  constexpr bool ab_match = std::is_same_v<TypeA, TypeB>;
  constexpr bool bc_match = std::is_same_v<TypeB, TypeC>;
  constexpr bool ac_match = std::is_same_v<TypeA, TypeC>;

  // Nothing in common between any pair of operands
  if constexpr (!ab_match && !bc_match && !ac_match) {
    return false;
  }

  if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
    if constexpr (ab_match && bc_match) {
      // Accepted types when A/B/C all match
      return std::is_same_v<TypeA, matxFp16> ||
             std::is_same_v<TypeA, matxBf16> ||
             std::is_same_v<TypeA, float> ||
             std::is_same_v<TypeA, double> ||
             std::is_same_v<TypeA, cuda::std::complex<float>> ||
             std::is_same_v<TypeA, cuda::std::complex<double>> ||
             std::is_same_v<TypeA, int8_t> ||
             std::is_same_v<TypeA, matxFp16Complex> ||
             std::is_same_v<TypeA, matxBf16Complex>;
    }
    else if constexpr (ab_match && !bc_match) {
      // Output type differs from the (shared) input type
      constexpr bool c_is_float = std::is_same_v<TypeC, float>;
      constexpr bool c_is_int32 = std::is_same_v<TypeC, int32_t>;
      return (std::is_same_v<TypeA, int8_t> && c_is_int32) ||
             (std::is_same_v<TypeA, int8_t> && c_is_float) ||
             (std::is_same_v<TypeA, matxBf16> && c_is_float) ||
             (std::is_same_v<TypeA, matxFp16> && c_is_float);
    }
  }

  return true;
}
/**
 * Parameters needed to execute a CUDA GEMM. For the most part, these are very
 * similar to that of a standard GEMM call.
 *
 * Every member carries a default initializer so that a default-constructed
 * instance never holds indeterminate values, including on code paths that
 * only fill in a subset of the fields.
 */
struct MatMulCUDAParams_t {
  // Matrix dimensions as handed to the cuBLASLt layouts
  index_t a_rows = 0;
  index_t a_cols = 0;
  index_t b_rows = 0;
  index_t b_cols = 0;
  index_t c_rows = 0;
  index_t c_cols = 0;
  // GEMM problem dimensions (filled in for the CUTLASS provider)
  index_t m = 0;
  index_t n = 0;
  index_t k = 0;
  // Leading dimensions
  index_t lda = 0;
  index_t ldb = 0;
  index_t ldc = 0;
  int rank = 0;
  int32_t batch = 1; // Must be int32_t for cuBLASLt
  index_t astride = 0; // batch stride
  index_t bstride = 0; // batch stride
  index_t cstride = 0; // batch stride
  MatMulCUDAProvider_t prov = PROVIDER_TYPE_CUBLASLT;
  cudaStream_t stream = nullptr;
  MatXDataType_t dtype{};
  cublasOperation_t opA = CUBLAS_OP_N;
  cublasOperation_t opB = CUBLAS_OP_N;
};
template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB,
MatMulCUDAProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
class MatMulCUDAHandle_t {
public:
using T1 = typename TensorTypeC::value_type;
using T2 = typename TensorTypeA::value_type;
using T3 = typename TensorTypeB::value_type;
static constexpr int RANK = TensorTypeC::Rank();
// We allow a batch stride of 0 on one of the tensors, so only make sure C's rank matches one of them
static_assert(TensorTypeC::Rank() == TensorTypeB::Rank() || TensorTypeB::Rank() == 2);
static_assert(TensorTypeC::Rank() == TensorTypeA::Rank() || TensorTypeA::Rank() == 2);
/**
* Construct a GEMM handle
*
* Creates a GEMM handle for the view shapes and provider type given. The view
* shapes are used to create the underlying metadata used for the GEMM, so a
* handle should only be used for views of identical sizes. The provider
* chooses the underlying library used to perform the GEMM. Certain providers
* have more features than others and may perform differently than others. At
* the moment, it is recommended to try different providers for a given matrix
* size until the optimal provider is found. Different providers may also be
* used by creating multiple handles.
*
* @tparam T1
* Type of C matrix
* @tparam T2
* Type of A matrix
* @tparam T3
* Type of B matrix
* @tparam PROV
* Provider type chosen from MatMulCUDAProvider_t type
*
* @param c
* C matrix view
* @param a
* A matrix view
* @param b
* B matrix view
*
*/
MatMulCUDAHandle_t(TensorTypeC &c, const TensorTypeA &a,
                   const TensorTypeB &b)
{
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

  static_assert(RANK >= 2);

  constexpr auto rank_a = TensorTypeA::Rank();
  constexpr auto rank_b = TensorTypeB::Rank();

  // Inner dimensions of A and B must agree, and C must be (rows of A) x (cols of B)
  MATX_ASSERT(a.Size(rank_a - 1) == b.Size(rank_b - 2), matxInvalidSize);
  MATX_ASSERT(c.Size(RANK - 1) == b.Size(rank_b - 1), matxInvalidSize);
  MATX_ASSERT(c.Size(RANK - 2) == a.Size(rank_a - 2), matxInvalidSize);

  // Batch dimensions only need to match C for operands that have C's rank
  for (int dim = 0; dim < RANK - 2; dim++) {
    if constexpr (RANK == rank_a) {
      MATX_ASSERT(a.Size(dim) == c.Size(dim), matxInvalidSize);
    }
    if constexpr (RANK == rank_b) {
      MATX_ASSERT(b.Size(dim) == c.Size(dim), matxInvalidSize);
    }
  }

  // The parameters have to be in place before any provider setup below
  params_ = GetGemmParams(c, a, b);

  if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
    // Use 32 MiB when Hopper or newer is detected, otherwise fall back to
    // 4 MiB, which still works on Hopper+.
    constexpr size_t MiB = 1024*1024;
    if (detail::IsHopperOrAbove()) {
      workspaceSize = 32*MiB;
    }
    else {
      workspaceSize = 4*MiB;
    }

    // Device workspace buffer handed to cuBLASLt
    matxAlloc(&workspace, workspaceSize, MATX_DEVICE_MEMORY);

    ConfigureCublasLt();
  }
}
/**
 * Write the alpha/beta scalars into caller-provided storage using the scalar
 * representation that corresponds to InputType.
 *
 * Complex float and half-precision complex inputs are written as cuComplex,
 * complex double as cuDoubleComplex, double as double, and float or
 * half-precision real inputs as float. The imaginary part of complex scalars
 * is set to zero. Any other InputType throws matxInvalidType.
 *
 * @tparam InputType  Element type used to choose the scalar representation
 * @param palpha  Destination storage for alpha
 * @param pbeta   Destination storage for beta
 * @param alpha   Alpha value
 * @param beta    Beta value
 */
template <typename InputType>
static void SetAlphaBeta([[maybe_unused]] char *const palpha,
                         [[maybe_unused]] char *const pbeta,
                         [[maybe_unused]] float const alpha,
                         [[maybe_unused]] float const beta)
{
  // For now we don't give much flexibility on compute type/alpha
  if constexpr (std::is_same_v<InputType, cuda::std::complex<float>> ||
                is_complex_half_v<InputType>) {
    *reinterpret_cast<cuComplex *>(palpha) = {alpha, 0};
    *reinterpret_cast<cuComplex *>(pbeta) = {beta, 0};
  }
  else if constexpr (std::is_same_v<InputType, cuda::std::complex<double>>) {
    *reinterpret_cast<cuDoubleComplex *>(palpha) = {alpha, 0};
    *reinterpret_cast<cuDoubleComplex *>(pbeta) = {beta, 0};
  }
  else if constexpr (std::is_same_v<InputType, double>) {
    *reinterpret_cast<double *>(palpha) = alpha;
    *reinterpret_cast<double *>(pbeta) = beta;
  }
  else if constexpr (is_matx_half_v<InputType> ||
                     std::is_same_v<InputType, float>) {
    *reinterpret_cast<float *>(palpha) = alpha;
    *reinterpret_cast<float *>(pbeta) = beta;
  }
  else {
    MATX_THROW(matxInvalidType, "Invalid type when deducing alpha/beta");
  }
}
/**
 * Build the GEMM parameter set for the given C/A/B views.
 *
 * Reads the sizes and strides of the last two dimensions of each view to
 * choose transpose modes, matrix dimensions and leading dimensions, and folds
 * the batch dimensions of C into a batch count with per-tensor batch strides.
 * The stream field is not assigned here.
 *
 * @param c
 *   Output tensor C
 * @param a
 *   Input tensor A
 * @param b
 *   Input tensor B
 *
 * @return Populated GEMM parameters
 */
static detail::MatMulCUDAParams_t GetGemmParams(TensorTypeC &c, const TensorTypeA &a,
                                                const TensorTypeB &b)
{
  /* If a user passes in a tensor where the last two dimensions are transposed we retain
     the original size parameters, but tell the underlying libraries that the tensors are
     in column-major ordering. The exception to this is when a transposed half-precision
     complex type is used. In that case we have to make a temporary copy of the tensor to
     put the data in planar format for the libraries. Since we now use the temporary tensor
     as input to the GEMM, the data is no longer transposed in memory and we simply use
     the same memory layout as a non-transposed real matrix would use.
  */
  // Value-initialize so that fields a given path does not assign (e.g. the
  // transpose mode when neither of the last two strides is 1) are zero rather
  // than indeterminate.
  detail::MatMulCUDAParams_t params{};
  params.dtype = TypeToInt<T1>();
  params.prov = PROV;
  params.rank = c.Rank();

  // Batches
  params.batch = 1;
  params.astride = 0;
  params.bstride = 0;
  params.cstride = 0;

  // For rank > 2 the dimensions ahead of the last two are batch dimensions.
  // At most the innermost (MATMUL_BATCH_RANK_THRESHOLD - 2) of them are folded
  // into the batch count here.
  if constexpr (RANK > 2) {
    auto c_shape = c.Shape();
    params.batch = static_cast<int>(std::accumulate(
        c_shape.begin() + std::max(RANK - MATMUL_BATCH_RANK_THRESHOLD, 0),
        c_shape.begin() + (RANK - 2),
        static_cast<size_t>(1),
        std::multiplies<size_t>()));

    // A only contributes a batch stride when its rank matches C; otherwise it
    // is reused for every batch (stride 0)
    if constexpr (TensorTypeA::Rank() == RANK) {
      params.astride = a.Stride(TensorTypeA::Rank()-3);
    }
    else {
      params.astride = 0;
    }

    // Same rule for B
    if constexpr (TensorTypeB::Rank() == RANK) {
      params.bstride = b.Stride(TensorTypeB::Rank()-3);
    }
    else {
      params.bstride = 0;
    }

    // Batch stride of C
    params.cstride = c.Stride(RANK-3);
  }

  // If the user wants C transposed (as a permuted view), we need the output
  // matrix to still be MxN in memory. The reason is the permuted view will
  // handle viewing it as an NxM. To accomplish this we use the identity C' =
  // B'A', so we swap A and B and permute them.
  if (c.Stride(RANK - 2) == 1 && c.Size(RANK - 1) != 1) {
    // TODO this looks like repeat logic from what I put in elsewhere...
    // track this down later. For now adding an assert to see if it ever pops up.
    // If it does not we should delete this code.
    MATX_ASSERT_STR(false, matxInvalidDim, "Internal Matmul error. This should not be hit\n");

    if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
      // B takes the role of the first operand
      if (b.Stride(TensorTypeB::Rank() - 2) == 1) {
        params.opA = CUBLAS_OP_N;
        params.a_rows = b.Size(TensorTypeB::Rank() - 1);
        params.a_cols = b.Size(TensorTypeB::Rank() - 2);
        params.lda = b.Stride(TensorTypeB::Rank() - 1);
      }
      else if (b.Stride(TensorTypeB::Rank() - 1) == 1) {
        params.opA = CUBLAS_OP_T;
        params.a_rows = b.Size(TensorTypeB::Rank() - 2);
        params.a_cols = b.Size(TensorTypeB::Rank() - 1);
        params.lda = b.Stride(TensorTypeB::Rank() - 2);
      }

      // A takes the role of the second operand
      if (a.Stride(TensorTypeA::Rank() - 2) == 1) {
        params.opB = CUBLAS_OP_N;
        params.b_rows = a.Size(TensorTypeA::Rank() - 1);
        params.b_cols = a.Size(TensorTypeA::Rank() - 2);
        params.ldb = a.Stride(TensorTypeA::Rank() - 1);
      }
      else if (a.Stride(TensorTypeA::Rank() - 1) == 1) {
        params.opB = CUBLAS_OP_T;
        params.b_rows = a.Size(TensorTypeA::Rank() - 2);
        params.b_cols = a.Size(TensorTypeA::Rank() - 1);
        params.ldb = a.Stride(TensorTypeA::Rank() - 2);
      }

      params.c_rows = params.a_rows;
      params.c_cols = params.b_cols;
      params.ldc = c.Stride(RANK - 1);
    }
    else if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
      params.opA = CUBLAS_OP_N;
      params.opB = CUBLAS_OP_N;
      // Gemm problem dimensions for C' = B'A': m and n are swapped relative to
      // the untransposed case, while k remains the shared inner dimension
      // (columns of A == rows of B).
      params.m = static_cast<int>(b.Size(TensorTypeB::Rank() - 1));
      params.n = static_cast<int>(a.Size(TensorTypeA::Rank() - 2));
      params.k = static_cast<int>(a.Size(TensorTypeA::Rank() - 1));
      params.lda = static_cast<int>(b.Stride(TensorTypeB::Rank() - 1));
      params.ldb = static_cast<int>(a.Stride(TensorTypeA::Rank() - 1));
      params.ldc = static_cast<int>(c.Stride(RANK - 1));
    }
  }
  else {
    if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
      if constexpr (is_complex_half_v<typename TensorTypeA::value_type>) {
        // For half complex we always copy to a new tensor so it is always cublas op N
        params.opA = CUBLAS_OP_N;
      } else if ( a.Stride(TensorTypeA::Rank()-1) > 1 // last stride > 1
                || (a.Stride(TensorTypeA::Rank()-1) == 1 && a.Stride(TensorTypeA::Rank()-2) == 1 && a.Size(TensorTypeA::Rank()-1) != 1)) { // last strides both equal 1 and size > 1
        params.opA = CUBLAS_OP_T;
      } else { // otherwise row major
        params.opA = CUBLAS_OP_N;
      }

      if constexpr (is_complex_half_v<typename TensorTypeB::value_type>) {
        // For half complex we always copy to a new tensor so it is always cublas op N
        params.opB = CUBLAS_OP_N;
      } else if ( b.Stride(TensorTypeB::Rank()-1) > 1 // last stride > 1
                || (b.Stride(TensorTypeB::Rank()-1) == 1 && b.Stride(TensorTypeB::Rank()-2) == 1 && b.Size(TensorTypeB::Rank()-1) != 1)) { // last strides both equal 1 and size > 1
        params.opB = CUBLAS_OP_T;
      } else { // otherwise row major
        params.opB = CUBLAS_OP_N;
      }

      params.a_rows = a.Size(TensorTypeA::Rank() - 2);
      params.a_cols = a.Size(TensorTypeA::Rank() - 1);
      params.b_rows = b.Size(TensorTypeB::Rank() - 2);
      params.b_cols = b.Size(TensorTypeB::Rank() - 1);

      // set lda/ldb according to transpose modes. If we pass in a cloned tensor the second stride will be
      // 0, which cuBLAS doesn't like even though it's unused. Set it to something that it would be if the
      // matrix had more than 1 row.
      if (params.opB == CUBLAS_OP_T) {
        params.ldb = b.Stride(TensorTypeB::Rank() - 1);
      }
      else {
        params.ldb = b.Stride(TensorTypeB::Rank() - 2);
        params.ldb = (params.ldb == 0) ? b.Size(TensorTypeB::Rank() - 1) : params.ldb;
      }

      if (params.opA == CUBLAS_OP_T) {
        params.lda = a.Stride(TensorTypeA::Rank() - 1);
      }
      else {
        params.lda = a.Stride(TensorTypeA::Rank() - 2);
        params.lda = (params.lda == 0) ? a.Size(TensorTypeA::Rank() - 1) : params.lda;
      }

      // for complex half we have copied to planar row-major
      if constexpr (is_complex_half_v<typename TensorTypeB::value_type>) {
        params.ldb = b.Size(TensorTypeB::Rank()-1);
      }

      // for complex half we have copied to planar row-major. The decision for
      // lda is keyed on A's own value type.
      if constexpr (is_complex_half_v<typename TensorTypeA::value_type>) {
        params.lda = a.Size(TensorTypeA::Rank()-1);
      }

      params.c_rows = params.a_rows;
      params.c_cols = params.b_cols;
      params.ldc = c.Stride(RANK - 2);
    }
    else if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
      params.opA = CUBLAS_OP_N;
      params.opB = CUBLAS_OP_N;
      // Gemm Problem dimensions
      params.m = static_cast<int>(a.Size(TensorTypeA::Rank() - 2));
      params.n = static_cast<int>(b.Size(TensorTypeB::Rank() - 1));
      params.k = static_cast<int>(a.Size(TensorTypeA::Rank() - 1));
      params.lda = static_cast<int>(a.Stride(TensorTypeA::Rank() - 2));
      params.ldb = static_cast<int>(b.Stride(TensorTypeB::Rank() - 2));
      params.ldc = static_cast<int>(c.Stride(RANK - 2));
    }
  }

  return params;
}
/**
* GEMM handle destructor
*
* Destroys any helper data used for provider type and any workspace memory
* created
*
*/
~MatMulCUDAHandle_t()
{
  // Workspace is only allocated for the cuBLASLt provider; otherwise it is
  // still nullptr here
  matxFree(workspace);

  if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
    cublasLtMatmulPreferenceDestroy(preference);
    cublasLtMatrixLayoutDestroy(Cdesc);
    cublasLtMatrixLayoutDestroy(Bdesc);
    cublasLtMatrixLayoutDestroy(Adesc);
    cublasLtMatmulDescDestroy(operationDesc);
    // Release the library handle obtained from cublasLtCreate() during
    // configuration. It is destroyed last, after the objects above.
    cublasLtDestroy(ltHandle);
  }

  // Temporary buffers used for the half-precision complex planar copies
  matxFree(a_hp);
  matxFree(b_hp);
  matxFree(c_hp);
}
/**
* Execute a Matrix multiply (GEMM)
*
* Execute a matrix multiply operation on two rank=2 input tensors into an
* output tensor. Using BLAS notation, tensor A has dimensions MxK, B is KxN,
* and C is MxN. Concretely:
*
* \f$\textbf{C} = \alpha\textbf{A}\textbf{B} + \beta\textbf{C}\f$
*
* MatX will perform runtime checks ensuring that the dimension constraints are
* met on all views. Unlike BLAS GEMMs, most parameters of the GEMM call are
* deduced from the view itself; there is no need to specify dimensions or
* transpose operations. MatX will attempt to perform the GEMM in the most
* efficient way possible given the knowledge of the view.
*
* While GEMMs are strictly rank=2 functions, rank 3 and higher tensors may be
* passed to this function, which has the effect of batching across the higher
* dimensions.
*
* @note views being passed to matxGemm must not be permuted and must
* currently have a contiguous stride.
*
* @param c
* Output tensor C
* @param a
* Input tensor A
* @param b
* Input tensor B
* @param stream
* CUDA stream
* @param alpha
* Alpha value
* @param beta
* Beta value
*
*/
__MATX_INLINE__ void Exec(TensorTypeC &c, const TensorTypeA &a,
const TensorTypeB &b, cudaStream_t stream,
float alpha = 1.0f, float beta = 0.0f)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
// Forward to the dispatcher, which takes the operands in A, B, C order
// (the output moves from first to third position) to match the cutlass API
MatMulDispatchA(a, b, c, stream, alpha, beta);
}
private:
// Member variables
// cuBLASLt library handle; created in ConfigureCublasLt()
cublasLtHandle_t ltHandle = nullptr;
// Status of the most recent cuBLASLt configuration call
cublasStatus_t ret = CUBLAS_STATUS_SUCCESS;
// cuBLASLt variables;
cublasHandle_t handle = nullptr;
cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr;
cublasLtMatrixLayout_t Bdesc = nullptr;
cublasLtMatrixLayout_t Cdesc = nullptr;
cublasLtMatmulPreference_t preference = nullptr;
cublasLtMatrixTransformDesc_t transformDescI = nullptr;
cublasLtMatrixTransformDesc_t transformDescO = nullptr;
cublasLtMatrixLayout_t AtransformDesc = nullptr;
cublasLtMatrixLayout_t BtransformDesc = nullptr;
cublasLtMatrixLayout_t CtransformDesc = nullptr;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
// Buffers for planar copies of half-precision complex operands
void *c_hp = nullptr; // Make these void since they only work on complex types
void *a_hp = nullptr;
void *b_hp = nullptr;
// cuBLASLt workspace buffer and its size in bytes
size_t workspaceSize = 0;
void *workspace = nullptr;
// GEMM parameters computed by GetGemmParams() in the constructor
detail::MatMulCUDAParams_t params_;
void ConfigureCublasLt()
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
ret = cublasLtCreate(<Handle);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatmulPreferenceCreate(&preference);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatmulDescCreate(
&operationDesc, MatXTypeToCudaComputeType<T1>(),
MatXTypeToCudaType<T1>());
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize,
sizeof(workspaceSize));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
cublasLtOrder_t rowOrder = CUBLASLT_ORDER_ROW;
cublasLtOrder_t colOrder = CUBLASLT_ORDER_COL;
auto op = CUBLAS_OP_N;
// A operation
ret = cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op,
sizeof(op));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
// B operation
ret = cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op,
sizeof(op));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
// Update this later when we're more flexible on compute type
int32_t scaleType;
if constexpr (std::is_same_v<T1, float> || is_matx_half_v<T1>) {
scaleType = CUDA_R_32F;
}
else if constexpr (is_complex_half_v<T1> ||
std::is_same_v<T1, cuda::std::complex<float>>) {
scaleType = CUDA_C_32F;
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
scaleType = CUDA_C_64F;
}
else {
scaleType = CUDA_R_64F;
}
ret = cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType,
sizeof(scaleType));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
// Matrix layouts
ret = cublasLtMatrixLayoutCreate(
&Adesc, MatXTypeToCudaType<T2>(), params_.a_rows,
params_.a_cols, params_.lda);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutCreate(
&Bdesc, MatXTypeToCudaType<T3>(), params_.b_rows,
params_.b_cols, params_.ldb);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutCreate(
&Cdesc, MatXTypeToCudaType<T1>(), params_.c_rows,
params_.c_cols, params_.ldc);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
// Matrix data order
if (params_.opA == CUBLAS_OP_T) {
ret = cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder,
sizeof(colOrder));
}
else {
ret = cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder,
sizeof(rowOrder));
}
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
if (params_.opB == CUBLAS_OP_T) {
ret = cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder,
sizeof(colOrder));
}
else {
ret = cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder,
sizeof(rowOrder));
}
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder,
sizeof(rowOrder));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ¶ms_.batch,
sizeof(params_.batch));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ¶ms_.batch,
sizeof(params_.batch));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ¶ms_.batch,
sizeof(params_.batch));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
int64_t stride;
if constexpr (is_complex_half_v<T2>) {
// for complex half we have copied to planar row major
// we know the layout of this matrix is compact
if constexpr (TensorTypeA::Rank() == RANK) {
stride = params_.a_rows * params_.a_cols * 2;
}
else {
stride = 0;
}
}
else {
stride = params_.astride;
}
ret = cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride,
sizeof(stride));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
if constexpr (is_complex_half_v<T3>) {
// for complex half we have copied to planar row major
// we know the layout of this matrix is compact
if constexpr (TensorTypeB::Rank() == RANK) {
stride = params_.b_rows * params_.b_cols * 2;
}
else {
stride = 0;
}
}
else {
stride = params_.bstride;
}
ret = cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride,
sizeof(stride));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
if constexpr (is_complex_half_v<T1>) {
// for complex half we have copied to planar row major
// we know the layout of this matrix is compact
stride = params_.c_rows * params_.c_cols * 2;
}
else {
stride = params_.cstride;
}
ret = cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride,
sizeof(stride));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
if constexpr (is_complex_half_v<T1> && is_complex_half_v<T2>) {
// for complex half we have copied to planar row major
size_t planarA = (params_.a_rows * params_.a_cols * sizeof(T1)) / 2;
size_t planarB = (params_.b_rows * params_.b_cols * sizeof(T1)) / 2;
size_t planarC = (params_.c_rows * params_.c_cols * sizeof(T1)) / 2;
ret = cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarA,
sizeof(planarA));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarB,
sizeof(planarB));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
ret = cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarC,
sizeof(planarC));
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
}
int res;
ret = cublasLtMatmulAlgoGetHeuristic(ltHandle, operationDesc, Adesc,
Bdesc, Cdesc, Cdesc, preference,
1, &heuristicResult,
&res);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
MATX_ASSERT(res > 0, matxMatMulError);
}
// TODO: Fix the unused parameters once we support mixes of col/row on cublas
template <MemOrder_t OrderA, MemOrder_t OrderB, MemOrder_t OrderC>
__MATX_INLINE__ void
MatMulLaunch(const TensorTypeA &a, const TensorTypeB &b,
TensorTypeC &c, cudaStream_t stream,
[[maybe_unused]] float alpha, [[maybe_unused]] float beta)
{
MATX_ASSERT_STR(PROV < PROVIDER_TYPE_SENTINEL, matxInvalidParameter, "Provider type out of range");
if constexpr ((PROV == PROVIDER_TYPE_CUTLASS) &&
(is_complex_half_v<T1> || is_complex_half_v<T2>)) {
MATX_THROW(matxInvalidType,
"CUTLASS does not support complex fp16/bf16 yet");
}
if constexpr ((is_complex_half_v<T1> && !is_complex_half_v<T2>) ||
(is_complex_half_v<T2> && !is_complex_half_v<T3>) ||
(is_complex_half_v<T1> && !is_complex_half_v<T3>)) {
MATX_THROW(matxInvalidType,
"A/B/C types must all be half complex if any of them are");
}
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
// Make copies of each tensor in case we have to do a transformation before
// the GEMM
[[maybe_unused]] TensorTypeA a_adj { a };
[[maybe_unused]] TensorTypeB b_adj { b };
[[maybe_unused]] TensorTypeC c_adj { c };
// If the tensors are complex half precision, we need to do a planar
// transform since all libraries expect this format at the moment.
if constexpr (is_complex_half_v<T1>) {
auto a_shape = a.Shape();
*(a_shape.begin() + a.Rank() - 2) = a.Size(a.Rank() - 2) * 2;
if (a_hp == nullptr) {
matxAlloc(&a_hp, a.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
}
auto a_planar = make_tensor<typename T2::value_type>(reinterpret_cast<typename T2::value_type*>(a_hp), a_shape, false);
auto b_shape = b.Shape();
*(b_shape.begin() + b.Rank() - 2) = b.Size(b.Rank() - 2) * 2;
if (b_hp == nullptr) {
matxAlloc(&b_hp, b.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
}
auto b_planar = make_tensor<typename T3::value_type>(reinterpret_cast<typename T3::value_type*>(b_hp), b_shape, false);
auto c_shape = c.Shape();
*(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2;
if (c_hp == nullptr) {
matxAlloc(&c_hp, c.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
}
auto c_planar = make_tensor<typename T1::value_type>(reinterpret_cast<typename T1::value_type*>(c_hp), c_shape, false);
// Convert A/B to planar layout
(a_planar = planar(a)).run(stream);
(b_planar = planar(b)).run(stream);
// update pointers to planar data.
// must use Reset because types for planar are different
a_adj.Reset(reinterpret_cast<T1 *>(a_planar.Data()));
b_adj.Reset(reinterpret_cast<T2 *>(b_planar.Data()));
c_adj.Reset(reinterpret_cast<T3 *>(c_planar.Data()));
}
// Prep for batch looping
using shape_type = typename TensorTypeA::desc_type::shape_type;
[[maybe_unused]] cuda::std::array<shape_type, TensorTypeA::Rank()> a_idx{0};
[[maybe_unused]] cuda::std::array<shape_type, TensorTypeB::Rank()> b_idx{0};
[[maybe_unused]] cuda::std::array<shape_type, TensorTypeC::Rank()> c_idx{0};
[[maybe_unused]] auto a_shape = a.Shape();
[[maybe_unused]] size_t total_iter = 1;
// For rank > threshold, we loop and process the innermost "threshold" number of dimensions at each iteration.
if constexpr (RANK > MATMUL_BATCH_RANK_THRESHOLD) {
// Get total number of iterations needed for dimensions beyond the first two batch dims
[[maybe_unused]] auto c_shape = c.Shape();
total_iter = std::accumulate(c_shape.begin(),
c_shape.begin() + TensorTypeC::Rank() - MATMUL_BATCH_RANK_THRESHOLD,
static_cast<size_t>(1),
std::multiplies<size_t>());
}
// For cuBLASLt most of the parameters have already been set in the
// configure stage
if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
MatMulScaleType_t salpha, sbeta;
memset(&salpha, 0, sizeof(salpha));
memset(&sbeta, 0, sizeof(sbeta));
if constexpr (std::is_same_v<T1, cuda::std::complex<float>> ||
is_complex_half_v<T1>) {
salpha.cf32[0] = alpha;
sbeta.cf32[0] = beta;
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
salpha.cf64[0] = alpha;
sbeta.cf64[0] = beta;
}
else if constexpr (std::is_same_v<T1, float> || is_matx_half_v<T1>) {
salpha.f32 = alpha;
sbeta.f32 = beta;
}
else if constexpr (std::is_same_v<T1, double>) {
salpha.f64 = alpha;
sbeta.f64 = beta;
}
if constexpr (RANK <= MATMUL_BATCH_RANK_THRESHOLD) {
// For ranks up to threshold, we can handle everything in a single batch operation
[[maybe_unused]] auto res = cublasLtMatmul(
ltHandle, operationDesc, &salpha, (void *)a_adj.Data(), Adesc,
(void *)b_adj.Data(), Bdesc, &sbeta, (void *)c_adj.Data(), Cdesc,
(void *)c_adj.Data(), Cdesc, &heuristicResult.algo, workspace,
workspaceSize, stream);
MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError);
}
else {
// When the rank exceeds the threshold, we loop over the outer dimensions; each
// call to cublasLtMatmul processes the innermost 'threshold' number of dimensions
for (size_t iter = 0; iter < total_iter; iter++) {
// Get pointers into A/B/C for this round
auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx);
auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx);
auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx);
[[maybe_unused]] auto res = cublasLtMatmul(
ltHandle, operationDesc, &salpha, (void *)ap,
Adesc, (void *)bp, Bdesc, &sbeta,
(void *)cp, Cdesc, (void *)cp,
Cdesc, &heuristicResult.algo, workspace, workspaceSize,
stream);
MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError);
// Update all but the last 4 indices (2 matrix dims + 2 batch dims)
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, a_idx, MATMUL_BATCH_RANK_THRESHOLD);
UpdateIndices<TensorTypeB, shape_type, TensorTypeB::Rank()>(b_adj, b_idx, MATMUL_BATCH_RANK_THRESHOLD);
UpdateIndices<TensorTypeC, shape_type, TensorTypeC::Rank()>(c_adj, c_idx, MATMUL_BATCH_RANK_THRESHOLD);
}
}
}
if constexpr (RANK == 2) {
if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
#ifdef MATX_ENABLE_CUTLASS
using CutlassAOrder = std::conditional_t<OrderA == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassBOrder = std::conditional_t<OrderB == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassCOrder = std::conditional_t<OrderC == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassGemm =
cutlass::gemm::device::Gemm<T1, // Data-type of A matrix
CutlassAOrder, // Layout of A matrix
T2, // Data-type of B matrix
CutlassBOrder, // Layout of B matrix
T3, // Data-type of C matrix
CutlassCOrder>; // Layout of C matrix
typename CutlassGemm::Arguments args(
{static_cast<int>(params_.m), static_cast<int>(params_.n),
static_cast<int>(params_.k)}, // Gemm Problem dimensions
{a.Data(),
static_cast<int>(params_.lda)}, // Tensor-ref for source matrix A
{b.Data(),
static_cast<int>(params_.ldb)}, // Tensor-ref for source matrix B
{c.Data(),
static_cast<int>(params_.ldc)}, // Tensor-ref for source matrix C
{c.Data(),
static_cast<int>(
params_.ldc)}, // Tensor-ref for destination matrix D (may be
// different memory than source C matrix)
{static_cast<T1>(alpha), static_cast<T1>(beta)}); // Scalars used in the Epilogue
CutlassGemm gemm_operator;
cutlass::Status status = gemm_operator(args, nullptr, stream);
MATX_ASSERT(status == cutlass::Status::kSuccess, matxMatMulError);
#else
MATX_THROW(matxNotSupported, "CUTLASS not enabled!");
#endif
}
}
else {
static_assert(RANK > 2);
#ifdef MATX_ENABLE_CUTLASS
using CutlassAOrder = std::conditional_t<OrderA == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassBOrder = std::conditional_t<OrderB == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassCOrder = std::conditional_t<OrderC == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
using CutlassGemm = cutlass::gemm::device::GemmBatched<
T1, // Data-type of A matrix
CutlassAOrder, // Layout of A matrix
T2, // Data-type of B matrix
CutlassBOrder, // Layout of B matrix
T3, // Data-type of C matrix
CutlassCOrder>; // Layout of C matrix
#endif
if constexpr (RANK > MATMUL_BATCH_RANK_THRESHOLD) {
if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
#ifdef MATX_ENABLE_CUTLASS
for (size_t iter = 0; iter < total_iter; iter++) {
// Get pointers into A/B/C for this round
auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx);
auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx);
auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx);
typename CutlassGemm::Arguments args(
{static_cast<int>(params_.m), static_cast<int>(params_.n),
static_cast<int>(params_.k)}, // Gemm Problem dimensions
{ap,
static_cast<int>(
params_.lda)}, // Tensor-ref for source matrix A
a_adj.Stride(RANK - 3), // Batch Stride A
{bp,
static_cast<int>(
params_.ldb)}, // Tensor-ref for source matrix B
b_adj.Stride(RANK - 3), // Batch Stride B
{cp,
static_cast<int>(
params_.ldc)}, // Tensor-ref for source matrix C
c_adj.Stride(RANK - 3), // Batch Stride C
{cp,
static_cast<int>(
params_.ldc)}, // Tensor-ref for destination matrix D (may
// be different memory than source C matrix)
c_adj.Stride(RANK - 3), // Batch Stride C
{static_cast<T1>(alpha), static_cast<T1>(beta)},
params_.batch // Batch Dimension
); // Scalars used in the Epilogue
CutlassGemm gemm_operator;
cutlass::Status status = gemm_operator(args, nullptr, stream);
MATX_ASSERT(status == cutlass::Status::kSuccess, matxMatMulError);
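// Update all but the last 3 indices.
// NOTE(review): only a_idx is advanced here; b_idx and c_idx are left untouched,
// unlike the cuBLASLt loop, which advances all three. Confirm this is intended.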
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, a_idx, 3);
}
#else
MATX_THROW(matxNotSupported, "CUTLASS not enabled!");
#endif
}
else {
MATX_STATIC_ASSERT_STR(PROV < PROVIDER_TYPE_SENTINEL, matxInvalidParameter, "Invalid MatMul provider");
}
}
}
// If the tensors are complex half precision, we need to convert C back to
// interleaved format and free all temporary buffers