Skip to content

Commit 39ca590

Browse files
authored
【Inference】Modify Layout assert (#74415)
1 parent 9e2272c commit 39ca590

File tree

1 file changed

+14
-10
lines changed
  • paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel

1 file changed

+14
-10
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -76,16 +76,6 @@ struct GemmFpAIntB {
7676
using LayoutC = typename Mma::LayoutC;
7777
using ElementScale = typename Mma::IteratorA::Element;
7878

79-
// NOTE: (changwenbin) Currently only A row major and B column major are
80-
// supported. Other cases have not been verified yet.
81-
82-
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
83-
static_assert(
84-
platform::is_same<LayoutA, layout::RowMajor>::value &&
85-
platform::is_same<LayoutB, layout::ColumnMajor>::value,
86-
"A must be row major and B must be col major in cuda_arch >= sm75");
87-
#endif
88-
8979
static ComplexTransform const kTransformA = Mma::kTransformA;
9080
static ComplexTransform const kTransformB = Mma::kTransformA;
9181

@@ -97,6 +87,20 @@ struct GemmFpAIntB {
9787
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
9888
using ArchTag = typename Mma::ArchTag;
9989

90+
// NOTE: (changwenbin) Ensure that the GEMM layout meets the requirements
91+
// under different architectures.
92+
static_assert(
93+
(platform::is_same<ArchTag, arch::Sm75>::value &&
94+
platform::is_same<LayoutA, layout::RowMajor>::value &&
95+
platform::is_same<LayoutB, layout::ColumnMajor>::value) ||
96+
(platform::is_same<ArchTag, arch::Sm80>::value &&
97+
platform::is_same<LayoutA, layout::RowMajor>::value &&
98+
platform::is_same<LayoutB, layout::ColumnMajor>::value) ||
99+
(platform::is_same<ArchTag, arch::Sm70>::value &&
100+
platform::is_same<LayoutA, layout::RowMajor>::value &&
101+
platform::is_same<LayoutB, layout::RowMajor>::value),
102+
"A must be row major and B must be col major in cuda_arch >= sm75");
103+
100104
static int const kStages = Mma::kStages;
101105
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
102106
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;

0 commit comments

Comments
 (0)