Skip to content

Commit 4ec14dc

Browse files
authored
[SYCLomatic] Fix a bug that wmma APIs not migrated in template function (#2424)
Signed-off-by: Ziran Zhang <[email protected]>
1 parent ec7f43d commit 4ec14dc

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

clang/lib/DPCT/ExprAnalysis.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -859,15 +859,7 @@ void ExprAnalysis::analyzeExpr(const UnaryExprOrTypeTraitExpr *UETT) {
859859
}
860860

861861
inline void ExprAnalysis::analyzeExpr(const UnresolvedLookupExpr *ULE) {
862-
RefString.clear();
863-
llvm::raw_string_ostream OS(RefString);
864-
if (auto NNS = ULE->getQualifier()) {
865-
if (NNS->getKind() != clang::NestedNameSpecifier::SpecifierKind::Global) {
866-
NNS->print(OS, dpct::DpctGlobalInfo::getContext().getPrintingPolicy());
867-
}
868-
}
869-
ULE->getName().print(OS,
870-
dpct::DpctGlobalInfo::getContext().getPrintingPolicy());
862+
RefString = ULE->decls().begin().getDecl()->getQualifiedNameAsString();
871863
}
872864

873865
void ExprAnalysis::analyzeExpr(const ExplicitCastExpr *Cast) {

clang/lib/DPCT/WMMAAPIMigration.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,22 @@ using namespace clang::dpct;
1313
using namespace clang::ast_matchers;
1414

1515
void clang::dpct::WMMARule::registerMatcher(ast_matchers::MatchFinder &MF) {
16+
auto FuncName = []() {
17+
return hasAnyName("fill_fragment", "load_matrix_sync", "mma_sync",
18+
"store_matrix_sync");
19+
};
1620
MF.addMatcher(
1721
typeLoc(loc(qualType(hasDeclaration(namedDecl(allOf(
1822
hasAnyName("fragment", "matrix_a", "matrix_b", "row_major",
1923
"col_major", "accumulator", "layout_t"),
2024
hasDeclContext(namespaceDecl(hasName("wmma")))))))))
2125
.bind("type"),
2226
this);
23-
MF.addMatcher(callExpr(callee(functionDecl(allOf(
24-
hasAnyName("fill_fragment", "load_matrix_sync",
25-
"mma_sync", "store_matrix_sync"),
26-
hasDeclContext(namespaceDecl(hasName("wmma")))))))
27+
MF.addMatcher(callExpr(anyOf(callee(functionDecl(allOf(
28+
FuncName(), hasDeclContext(namespaceDecl(
29+
hasName("wmma")))))),
30+
callee(unresolvedLookupExpr(
31+
hasAnyDeclaration(namedDecl(FuncName()))))))
2732
.bind("call"),
2833
this);
2934
MF.addMatcher(declRefExpr(to(enumConstantDecl(allOf(

clang/test/dpct/wmma.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,5 +201,19 @@ int main() {
201201

202202
return 0;
203203
}
204+
205+
using namespace nvcuda;
206+
template<typename T>
207+
__global__ void simple_wmma_gemm(T *d) {
208+
wmma::fragment<wmma::accumulator, 16, 16, 16, T> c_frag;
209+
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, T>(d), 1, sycl::ext::oneapi::experimental::matrix::layout::row_major);
210+
wmma::store_matrix_sync(d, c_frag, 1, wmma::mem_row_major);
211+
}
212+
int main() {
213+
simple_wmma_gemm<half><<<1, 1>>>(nullptr);
214+
simple_wmma_gemm<float><<<1, 1>>>(nullptr);
215+
return 0;
216+
}
217+
204218
// clang-format on
205219
#endif

0 commit comments

Comments
 (0)