Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions llvm/include/llvm/SYCLPostLink/SpecializationConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//= SpecializationConstants.h - Processing of SYCL Specialization Constants ==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Specialization constants processing consists of lowering and generation
// of new module with spec consts replaced by default values.
//===----------------------------------------------------------------------===//

#ifndef LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H
#define LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H

#include "llvm/ADT/SmallVector.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"

#include <optional>

namespace llvm {
namespace sycl {

/// Metadata and intrinsics related to SYCL specialization constants are lowered
/// depending on the given
/// \p Mode. If \p Mode is std::nullopt, then no lowering happens.
/// If \p GenerateModuleDescWithDefaultSpecConsts is true, then a generation
/// of new modules with specialization constants replaced by default values
/// happens and the result is written in \p NewModuleDescs.
///
/// \returns Boolean value indicating whether the lowering has changed the input
/// modules.
bool handleSpecializationConstants(
llvm::SmallVectorImpl<module_split::ModuleDesc> &MDs,
std::optional<SpecConstantsPass::HandlingMode> Mode,
llvm::SmallVectorImpl<module_split::ModuleDesc> &NewModuleDescs,
bool GenerateModuleDescWithDefaultSpecConsts);

} // namespace sycl
} // namespace llvm

#endif // LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H
1 change: 1 addition & 0 deletions llvm/lib/SYCLPostLink/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_llvm_component_library(LLVMSYCLPostLink
ComputeModuleRuntimeInfo.cpp
ESIMDPostSplitProcessing.cpp
ModuleSplitter.cpp
SpecializationConstants.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLPostLink
Expand Down
93 changes: 93 additions & 0 deletions llvm/lib/SYCLPostLink/SpecializationConstants.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//= SpecializationConstants.cpp - Processing of SYCL Specialization Constants //
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// See comments in the header.
//===----------------------------------------------------------------------===//

#include "llvm/SYCLPostLink/SpecializationConstants.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"

#include <optional>

using namespace llvm;

namespace {

bool lowerSpecConstants(module_split::ModuleDesc &MD,
SpecConstantsPass::HandlingMode Mode) {
ModulePassManager RunSpecConst;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(Mode);
// Register required analysis.
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
RunSpecConst.addPass(std::move(SCP));

// Perform the specialization constant intrinsics transformation on resulting
// module.
PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM);
MD.Props.SpecConstsMet = !Res.areAllPreserved();
return MD.Props.SpecConstsMet;
}

/// Function generates the copy of the given \p MD where all uses of
/// Specialization constants are replaced by corresponding default values.
/// If the Module in \p MD doesn't contain specialization constants then
/// std::nullopt is returned.
std::optional<module_split::ModuleDesc>
cloneModuleWithSpecConstsReplacedByDefaultValues(
const module_split::ModuleDesc &MD) {
std::optional<module_split::ModuleDesc> NewMD;
if (!checkModuleContainsSpecConsts(MD.getModule()))
return NewMD;

NewMD = MD.clone();
NewMD->setSpecConstantDefault(true);

ModulePassManager MPM;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values);
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MPM.addPass(std::move(SCP));
MPM.addPass(StripDeadPrototypesPass());

PreservedAnalyses Res = MPM.run(NewMD->getModule(), MAM);
NewMD->Props.SpecConstsMet = !Res.areAllPreserved();
assert(NewMD->Props.SpecConstsMet &&
"SpecConstsMet should be true since the presence of SpecConsts "
"has been checked before the run of the pass");
NewMD->rebuildEntryPoints();
return NewMD;
}

} // namespace

bool llvm::sycl::handleSpecializationConstants(
SmallVectorImpl<module_split::ModuleDesc> &MDs,
std::optional<SpecConstantsPass::HandlingMode> Mode,
SmallVectorImpl<module_split::ModuleDesc> &NewModuleDescs,
bool GenerateModuleDescWithDefaultSpecConsts) {
bool Modified = false;
for (module_split::ModuleDesc &MD : MDs) {
if (GenerateModuleDescWithDefaultSpecConsts)
if (std::optional<module_split::ModuleDesc> NewMD =
cloneModuleWithSpecConstsReplacedByDefaultValues(MD))
NewModuleDescs.push_back(std::move(*NewMD));

if (Mode)
Modified |= lowerSpecConstants(MD, *Mode);
}

return Modified;
}
71 changes: 10 additions & 61 deletions llvm/tools/sycl-post-link/sycl-post-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/SYCLPostLink/ComputeModuleRuntimeInfo.h"
#include "llvm/SYCLPostLink/ESIMDPostSplitProcessing.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"
#include "llvm/SYCLPostLink/SpecializationConstants.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
Expand All @@ -48,7 +49,6 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"

#include <algorithm>
#include <memory>
Expand Down Expand Up @@ -429,56 +429,6 @@ void saveDeviceLibModule(
saveModule(OutTables, DeviceLibMD, I, OutputPrefix, "");
}

bool processSpecConstants(module_split::ModuleDesc &MD) {
MD.Props.SpecConstsMet = false;

if (SpecConstLower.getNumOccurrences() == 0)
return false;

ModulePassManager RunSpecConst;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstLower == SC_NATIVE_MODE
? SpecConstantsPass::HandlingMode::native
: SpecConstantsPass::HandlingMode::emulation);
// Register required analysis
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
RunSpecConst.addPass(std::move(SCP));

// Perform the spec constant intrinsics transformation on resulting module
PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM);
MD.Props.SpecConstsMet = !Res.areAllPreserved();
return MD.Props.SpecConstsMet;
}

/// Function generates the copy of the given ModuleDesc where all uses of
/// Specialization Constants are replaced by corresponding default values.
/// If the Module in MD doesn't contain specialization constants then
/// std::nullopt is returned.
std::optional<module_split::ModuleDesc>
processSpecConstantsWithDefaultValues(const module_split::ModuleDesc &MD) {
std::optional<module_split::ModuleDesc> NewModuleDesc;
if (!checkModuleContainsSpecConsts(MD.getModule()))
return NewModuleDesc;

NewModuleDesc = MD.clone();
NewModuleDesc->setSpecConstantDefault(true);

ModulePassManager MPM;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values);
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MPM.addPass(std::move(SCP));
MPM.addPass(StripDeadPrototypesPass());

PreservedAnalyses Res = MPM.run(NewModuleDesc->getModule(), MAM);
NewModuleDesc->Props.SpecConstsMet = !Res.areAllPreserved();
assert(NewModuleDesc->Props.SpecConstsMet &&
"This property should be true since the presence of SpecConsts "
"has been checked before the run of the pass");
NewModuleDesc->rebuildEntryPoints();
return NewModuleDesc;
}

constexpr int MAX_COLUMNS_IN_FILE_TABLE = 3;

void addTableRow(util::SimpleTable &Table,
Expand Down Expand Up @@ -602,6 +552,12 @@ processInputModule(std::unique_ptr<Module> M, const StringRef OutputPrefix) {
error(toString(std::move(E)));
}

std::optional<SpecConstantsPass::HandlingMode> SCMode;
if (SpecConstLower.getNumOccurrences() > 0)
SCMode = SpecConstLower == SC_NATIVE_MODE
? SpecConstantsPass::HandlingMode::native
: SpecConstantsPass::HandlingMode::emulation;

// It is important that we *DO NOT* preserve all the splits in memory at the
// same time, because it leads to a huge RAM consumption by the tool on bigger
// inputs.
Expand All @@ -624,16 +580,9 @@ processInputModule(std::unique_ptr<Module> M, const StringRef OutputPrefix) {
SmallVector<module_split::ModuleDesc, 2> &MMs = *ModulesOrErr;
assert(MMs.size() && "at least one module is expected after ESIMD split");
SmallVector<module_split::ModuleDesc, 2> MMsWithDefaultSpecConsts;
for (size_t I = 0; I != MMs.size(); ++I) {
if (GenerateDeviceImageWithDefaultSpecConsts) {
std::optional<module_split::ModuleDesc> NewMD =
processSpecConstantsWithDefaultValues(MMs[I]);
if (NewMD)
MMsWithDefaultSpecConsts.push_back(std::move(*NewMD));
}

Modified |= processSpecConstants(MMs[I]);
}
Modified |=
handleSpecializationConstants(MMs, SCMode, MMsWithDefaultSpecConsts,
GenerateDeviceImageWithDefaultSpecConsts);

if (IROutputOnly) {
if (SplitOccurred) {
Expand Down
Loading