| 
13 | 13 | #include "clang/Driver/Job.h"  | 
14 | 14 | #include "llvm/ADT/StringSwitch.h"  | 
15 | 15 | #include "llvm/TargetParser/Triple.h"  | 
 | 16 | +#include <regex>  | 
16 | 17 | 
 
  | 
17 | 18 | using namespace clang::driver;  | 
18 | 19 | using namespace clang::driver::tools;  | 
@@ -173,6 +174,39 @@ bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {  | 
173 | 174 |   return true;  | 
174 | 175 | }  | 
175 | 176 | 
 
  | 
 | 177 | +std::string getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs) {  | 
 | 178 | +  if (SpvExtensionArgs.empty()) {  | 
 | 179 | +    return "-spirv-ext=all";  | 
 | 180 | +  }  | 
 | 181 | + | 
 | 182 | +  std::string LlvmOption =  | 
 | 183 | +      (Twine("-spirv-ext=+") + SpvExtensionArgs.front()).str();  | 
 | 184 | +  SpvExtensionArgs = SpvExtensionArgs.slice(1);  | 
 | 185 | +  for (auto Extension : SpvExtensionArgs) {  | 
 | 186 | +    LlvmOption = (Twine(LlvmOption) + ",+" + Extension).str();  | 
 | 187 | +  }  | 
 | 188 | +  return LlvmOption;  | 
 | 189 | +}  | 
 | 190 | + | 
 | 191 | +bool isValidSPIRVExtensionName(const std::string &str) {  | 
 | 192 | +  std::regex pattern("SPV_[a-zA-Z0-9_]+");  | 
 | 193 | +  return std::regex_match(str, pattern);  | 
 | 194 | +}  | 
 | 195 | + | 
 | 196 | +// SPIRV extension names are of the form `SPV_[a-zA-Z0-9_]+`. We want to  | 
 | 197 | +// disallow obviously invalid names to avoid issues when parsing `spirv-ext`.  | 
 | 198 | +bool checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,  | 
 | 199 | +                                const Driver &Driver) {  | 
 | 200 | +  bool AllValid = true;  | 
 | 201 | +  for (auto Extension : SpvExtensionArgs) {  | 
 | 202 | +    if (!isValidSPIRVExtensionName(Extension)) {  | 
 | 203 | +      Driver.Diag(diag::err_drv_invalid_value)  | 
 | 204 | +          << "-fspv_extension" << Extension;  | 
 | 205 | +      AllValid = false;  | 
 | 206 | +    }  | 
 | 207 | +  }  | 
 | 208 | +  return AllValid;  | 
 | 209 | +}  | 
176 | 210 | } // namespace  | 
177 | 211 | 
 
  | 
178 | 212 | void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,  | 
@@ -301,6 +335,17 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,  | 
301 | 335 |     DAL->append(A);  | 
302 | 336 |   }  | 
303 | 337 | 
 
  | 
 | 338 | +  if (getArch() == llvm::Triple::spirv) {  | 
 | 339 | +    std::vector<std::string> SpvExtensionArgs =  | 
 | 340 | +        Args.getAllArgValues(options::OPT_fspv_extension_EQ);  | 
 | 341 | +    if (checkExtensionArgsAreValid(SpvExtensionArgs, getDriver())) {  | 
 | 342 | +      std::string LlvmOption = getSpirvExtArg(SpvExtensionArgs);  | 
 | 343 | +      DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm),  | 
 | 344 | +                          LlvmOption);  | 
 | 345 | +    }  | 
 | 346 | +    Args.claimAllArgs(options::OPT_fspv_extension_EQ);  | 
 | 347 | +  }  | 
 | 348 | + | 
304 | 349 |   if (!DAL->hasArg(options::OPT_O_Group)) {  | 
305 | 350 |     DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), "3");  | 
306 | 351 |   }  | 
 | 
0 commit comments