@@ -76,16 +76,6 @@ struct GemmFpAIntB {
76
76
using LayoutC = typename Mma::LayoutC;
77
77
using ElementScale = typename Mma::IteratorA::Element;
78
78
79
- // NOTE: (changwenbin) Currently only A row major and B column major are
80
- // supported. Other cases have not been verified yet.
81
-
82
- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
83
- static_assert (
84
- platform::is_same<LayoutA, layout::RowMajor>::value &&
85
- platform::is_same<LayoutB, layout::ColumnMajor>::value,
86
- " A must be row major and B must be col major in cuda_arch >= sm75" );
87
- #endif
88
-
89
79
static ComplexTransform const kTransformA = Mma::kTransformA ;
90
80
// NOTE(review): kTransformB is initialized from Mma::kTransformA, not Mma::kTransformB — confirm this is intentional.
static ComplexTransform const kTransformB = Mma::kTransformA ;
91
81
@@ -97,6 +87,20 @@ struct GemmFpAIntB {
97
87
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
98
88
using ArchTag = typename Mma::ArchTag;
99
89
90
+ // NOTE: (changwenbin) Accepted ArchTag/layout combinations: Sm75 and Sm80 need
91
+ // A row major with B column major; Sm70 needs A and B both row major.
92
+ static_assert (
93
+ (platform::is_same<ArchTag, arch::Sm75>::value &&
94
+ platform::is_same<LayoutA, layout::RowMajor>::value &&
95
+ platform::is_same<LayoutB, layout::ColumnMajor>::value) ||
96
+ (platform::is_same<ArchTag, arch::Sm80>::value &&
97
+ platform::is_same<LayoutA, layout::RowMajor>::value &&
98
+ platform::is_same<LayoutB, layout::ColumnMajor>::value) ||
99
+ (platform::is_same<ArchTag, arch::Sm70>::value &&
100
+ platform::is_same<LayoutA, layout::RowMajor>::value &&
101
+ platform::is_same<LayoutB, layout::RowMajor>::value),
102
+ "Unsupported ArchTag/layout combination: A must be row major, and B must be column major for Sm75/Sm80 or row major for Sm70" );
103
+
100
104
static int const kStages = Mma::kStages ;
101
105
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements ;
102
106
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements ;
0 commit comments