@@ -350,34 +350,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
350350 return ResultType;
351351}
352352
353- template <bool NeedTypeInterpret = false >
354- llvm::Type *getJointMatrixINTELExtType (llvm::Type *CompTy,
355- ArrayRef<TemplateArgument> TemplateArgs,
356- const unsigned Val = 0 ) {
357- // TODO: we should actually have exactly 5 template parameters: 1 for
358- // type and 4 for type parameters. But in previous version of the SPIR-V
359- // spec we have Layout matrix type parameter, that was later removed.
360- // Once we update to the newest version of the spec - this should be updated.
361- assert ((TemplateArgs.size () == 5 || TemplateArgs.size () == 6 ) &&
362- " Wrong JointMatrixINTEL template parameters number" );
363- // This is required to represent optional 'Component Type Interpretation'
364- // parameter
365- std::vector<unsigned > Params;
366- for (size_t I = 1 ; I != TemplateArgs.size (); ++I) {
367- assert (TemplateArgs[I].getKind () == TemplateArgument::Integral &&
368- " Wrong JointMatrixINTEL template parameter" );
369- Params.push_back (TemplateArgs[I].getAsIntegral ().getExtValue ());
370- }
371- // Don't add type interpretation for legacy matrices.
372- // Legacy matrices has 5 template parameters, while new representation
373- // has 6.
374- if (NeedTypeInterpret && TemplateArgs.size () != 5 )
375- Params.push_back (Val);
376-
377- return llvm::TargetExtType::get (CompTy->getContext (),
378- " spirv.JointMatrixINTEL" , {CompTy}, Params);
379- }
380-
381353llvm::Type *
382354getCooperativeMatrixKHRExtType (llvm::Type *CompTy,
383355 ArrayRef<TemplateArgument> TemplateArgs) {
@@ -394,49 +366,6 @@ getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
394366 CompTy->getContext (), " spirv.CooperativeMatrixKHR" , {CompTy}, Params);
395367}
396368
397- // / ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
398- // / which is represented as a pointer to a structure to LLVM extension type
399- // / with the parameters that follow SPIR-V JointMatrixINTEL type.
400- // / The expected representation is:
401- // / target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
402- // / %use%, (optional) %element_type_interpretation%)
403- llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType (RecordDecl *RD) {
404- auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
405- ArrayRef<TemplateArgument> TemplateArgs =
406- TemplateDecl->getTemplateArgs ().asArray ();
407- assert (TemplateArgs[0 ].getKind () == TemplateArgument::Type &&
408- " 1st JointMatrixINTEL template parameter must be type" );
409- llvm::Type *CompTy = ConvertType (TemplateArgs[0 ].getAsType ());
410-
411- // Per JointMatrixINTEL spec the type can have an optional
412- // 'Component Type Interpretation' parameter. We should emit it in case
413- // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
414- // matrix's components. Yet 'bfloat16' should be represented as 'int16' and
415- // 'tf32' as 'float' types.
416- if (CompTy->isStructTy ()) {
417- StringRef LlvmTyName = CompTy->getStructName ();
418- // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
419- if (LlvmTyName.starts_with (" class.sycl::" ) ||
420- LlvmTyName.starts_with (" class.__sycl_internal::" ))
421- LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
422- if (LlvmTyName == " half" ) {
423- CompTy = llvm::Type::getHalfTy (getLLVMContext ());
424- return getJointMatrixINTELExtType (CompTy, TemplateArgs);
425- } else if (LlvmTyName == " tf32" ) {
426- CompTy = llvm::Type::getFloatTy (getLLVMContext ());
427- // 'tf32' interpretation is mapped to '0'
428- return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 0 );
429- } else if (LlvmTyName == " bfloat16" ) {
430- CompTy = llvm::Type::getInt16Ty (getLLVMContext ());
431- // 'bfloat16' interpretation is mapped to '1'
432- return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 1 );
433- } else {
434- llvm_unreachable (" Wrong matrix base type!" );
435- }
436- }
437- return getJointMatrixINTELExtType (CompTy, TemplateArgs);
438- }
439-
440369// / ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
441370// / which is represented as a pointer to a structure to LLVM extension type
442371// / with the parameters that follow SPIR-V CooperativeMatrixKHR type.
@@ -739,11 +668,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
739668 if (ClangETy && ClangETy->isStructureOrClassType ()) {
740669 RecordDecl *RD = ClangETy->getAsCXXRecordDecl ();
741670 if (RD && RD->getQualifiedNameAsString () ==
742- " __spv::__spirv_JointMatrixINTEL" ) {
743- ResultType = ConvertSYCLJointMatrixINTELType (RD);
744- break ;
745- } else if (RD && RD->getQualifiedNameAsString () ==
746- " __spv::__spirv_CooperativeMatrixKHR" ) {
671+ " __spv::__spirv_CooperativeMatrixKHR" ) {
747672 ResultType = ConvertSPVCooperativeMatrixType (RD);
748673 break ;
749674 } else if (RD && RD->getQualifiedNameAsString () ==
0 commit comments