Skip to content

Commit 72f80b5

Browse files
[Phase 1] full flow
1 parent 8a5c8eb commit 72f80b5

File tree

8 files changed

+444
-0
lines changed

8 files changed

+444
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ pythonenv*
7373
/clang/utils/analyzer/projects/*/RefScanBuildResults
7474
# automodapi puts generated documentation files here.
7575
/lldb/docs/python_api/
76+
/install

TTL_MLIR_Integration.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# TTL MLIR Integration
2+
3+
## Project Overview
4+
This project aims to integrate TTL (Template Tiling Library) with MLIR to create an optimized pipeline from C code to TTL-optimized C code. The pipeline includes affine loop tiling and dialect conversions, with a focus on optimizing operations like sigmoid.
5+
6+
## Current Pipeline
7+
```
8+
C code with TTL DSL → MLIR → Optimized MLIR → EmitC → C code
9+
```
10+
11+
## Technical Implementation
12+
13+
### Version Compatibility
14+
- Using LLVM 20 for MLIR pipeline
15+
- Polygeist (C → MLIR) is on LLVM 18
16+
- Workaround: manually remove the incompatible parts
17+
- This is a manageable limitation for now
18+
19+
### Type System Integration
20+
- Minor issue with unrealized conversion casts
21+
- Can be fixed with a simple pass if needed
22+
- Not a critical blocker
23+
24+
### TTL Integration Strategy
25+
Two possible approaches:
26+
1. Generate direct function calls to TTL's existing functions
27+
2. Create a TTL dialect (if needed)
28+
- Currently leaning towards function calls for simplicity
29+
- Decision pending based on future requirements
30+
31+
## Next Steps
32+
33+
### 1. Frontend Definition
34+
- Define Polygeist as the frontend
35+
- Its output will feed into TTL optimizer passes (like tiling)
36+
- Currently supporting minimal 2D loops and array access
37+
- Will expand TTL DSL features in the frontend
38+
39+
### 2. Backend Generation
40+
- Develop pipeline to generate TTL-specific code
41+
- Focus on efficient memory operations and tiling
42+
43+
### 3. TTL DSL Development
44+
- Currently minimal: 2D loops and array access
45+
- Will expand based on requirements
46+
- Starting with sigmoid as a test case
47+
48+
### 4. Immediate Focus
49+
- Optimizing sigmoid function
50+
- Using it as a test case for the complete pipeline
51+
- Will use learnings to expand to other operations
52+
53+
## Technical Decisions
54+
- Keeping things simple by using function calls rather than introducing a new dialect
55+
- Managing version compatibility manually for now
56+
- Type conversion issues are minor and can be addressed if needed
57+
58+
## Current Limitations
59+
1. Version mismatch between Polygeist and MLIR pipeline
60+
2. Minimal TTL DSL features in frontend
61+
3. Focus on sigmoid optimization only
62+
63+
## Future Work
64+
1. Expand TTL DSL features
65+
2. Add more optimization passes
66+
3. Support more complex operations
67+
4. Evaluate need for TTL dialect
68+
5. Consider automating version compatibility fixes

mlir/include/mlir/Transforms/Passes.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class GreedyRewriteConfig;
4646
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
4747
#define GEN_PASS_DECL_TOPOLOGICALSORT
4848
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
49+
#define GEN_PASS_DECL_TTLOPS
50+
#define GEN_PASS_DECL_TTLPIPELINE
51+
#define GEN_PASS_DECL_TTLTOEMITC
4952
#include "mlir/Transforms/Passes.h.inc"
5053

5154
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -65,6 +68,15 @@ createCanonicalizerPass(const GreedyRewriteConfig &config,
6568
ArrayRef<std::string> disabledPatterns = std::nullopt,
6669
ArrayRef<std::string> enabledPatterns = std::nullopt);
6770

71+
/// Creates a TTL ops pass.
72+
std::unique_ptr<Pass> createTTLOpsPass();
73+
74+
/// Creates a TTL pipeline pass that runs multiple passes.
75+
std::unique_ptr<Pass> createTTLPipelinePass();
76+
77+
/// Creates a pass that converts TTL operations to the EmitC dialect.
78+
std::unique_ptr<Pass> createTTLToEmitC();
79+
6880
/// Creates a pass to perform control-flow sinking.
6981
std::unique_ptr<Pass> createControlFlowSinkPass();
7082

mlir/include/mlir/Transforms/Passes.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,24 @@ def Canonicalizer : Pass<"canonicalize"> {
5454
] # RewritePassUtils.options;
5555
}
5656

57+
def TTLOps : Pass<"ttl-ops", "ModuleOp"> {
58+
let summary = "Convert TTL operations to MLIR";
59+
let description = [{
60+
This pass converts TTL operations to their MLIR equivalents.
61+
}];
62+
let constructor = "mlir::createTTLOpsPass()";
63+
let dependentDialects = ["func::FuncDialect"];
64+
}
65+
66+
def TTLPipeline : Pass<"ttl-pipeline", "ModuleOp"> {
67+
let summary = "Run a pipeline of TTL passes";
68+
let description = [{
69+
This pass runs a sequence of TTL-related passes in a specific order.
70+
}];
71+
let constructor = "mlir::createTTLPipelinePass()";
72+
let dependentDialects = ["func::FuncDialect"];
73+
}
74+
5775
def ControlFlowSink : Pass<"control-flow-sink"> {
5876
let summary = "Sink operations into conditional blocks";
5977
let description = [{
@@ -586,4 +604,12 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
586604
];
587605
}
588606

607+
def TTLToEmitC : Pass<"ttl-to-emitc", "func::FuncOp"> {
608+
let summary = "Convert TTL operations to EmitC dialect";
609+
let description = [{
610+
This pass converts TTL operations to EmitC dialect for C code generation.
611+
}];
612+
let dependentDialects = ["mlir::emitc::EmitCDialect"];
613+
}
614+
589615
#endif // MLIR_TRANSFORMS_PASSES

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
add_subdirectory(Utils)
22

33
add_mlir_library(MLIRTransforms
4+
TTLOps.cpp
5+
TTLPipeline.cpp
6+
TTLToEmitC.cpp
47
Canonicalizer.cpp
58
CompositePass.cpp
69
ControlFlowSink.cpp

mlir/lib/Transforms/TTLOps.cpp

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#include "mlir/Pass/Pass.h"
2+
#include "mlir/IR/BuiltinOps.h"
3+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
4+
#include "mlir/Dialect/Func/IR/FuncOps.h"
5+
#include "mlir/IR/Builders.h"
6+
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
7+
#include "mlir/Dialect/Affine/Analysis/Utils.h"
8+
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
9+
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
10+
#include "mlir/Dialect/Affine/LoopUtils.h"
11+
#include "llvm/Support/raw_ostream.h"
12+
13+
using namespace mlir;
14+
using namespace mlir::affine;
15+
16+
namespace {
17+
18+
// Core data structures for analyzing loops and memory accesses
19+
struct LoopInfo {
20+
// Loop bounds and step
21+
int64_t lowerBound;
22+
int64_t upperBound;
23+
int64_t step;
24+
25+
// Memory accesses in this loop
26+
enum class AccessType {
27+
Load,
28+
Store
29+
};
30+
31+
struct MemoryAccess {
32+
Value memref; // The memref being accessed
33+
AffineMap accessMap; // The affine map for the access
34+
AccessType type; // Whether it's a load or store
35+
};
36+
SmallVector<MemoryAccess> accesses;
37+
};
38+
39+
// Helper class to validate loop structures and memory accesses
40+
class LoopValidator {
41+
public:
42+
// Check if a memory access is 2D
43+
static bool is2DAccess(Operation *op) {
44+
AffineMap map;
45+
if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
46+
map = loadOp.getAffineMap();
47+
} else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
48+
map = storeOp.getAffineMap();
49+
} else {
50+
assert(false && "Expected load or store operation");
51+
}
52+
return map.getNumResults() == 2;
53+
}
54+
55+
56+
// Validate loop band and collect information if valid
57+
static std::optional<SmallVector<LoopInfo>> validateAndCollectInfo(ArrayRef<AffineForOp> loops) {
58+
// Check if it's a 2D perfectly nested loop
59+
if (loops.size() != 2 || !affine::isPerfectlyNested(loops)) {
60+
return std::nullopt;
61+
}
62+
63+
SmallVector<LoopInfo> loopInfos;
64+
65+
// Analyze each loop
66+
for (const auto &loop : loops) {
67+
LoopInfo info;
68+
69+
// Get loop bounds and check if they're compile-time constants
70+
auto lowerMap = const_cast<AffineForOp &>(loop).getLowerBoundMap();
71+
auto upperMap = const_cast<AffineForOp &>(loop).getUpperBoundMap();
72+
73+
if (!lowerMap.isConstant() || !upperMap.isConstant()) {
74+
return std::nullopt;
75+
}
76+
77+
info.lowerBound = lowerMap.getSingleConstantResult();
78+
info.upperBound = upperMap.getSingleConstantResult();
79+
info.step = const_cast<AffineForOp &>(loop).getStep().getSExtValue();
80+
81+
// Only collect memory accesses in the innermost loop
82+
if (loop == loops.back()) {
83+
bool all2D = true;
84+
loop->walk([&](Operation *op) {
85+
if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
86+
if (!is2DAccess(op)) {
87+
all2D = false;
88+
return;
89+
}
90+
info.accesses.push_back({loadOp.getMemRef(), loadOp.getAffineMap(), LoopInfo::AccessType::Load});
91+
} else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
92+
if (!is2DAccess(op)) {
93+
all2D = false;
94+
return;
95+
}
96+
info.accesses.push_back({storeOp.getMemRef(), storeOp.getAffineMap(), LoopInfo::AccessType::Store});
97+
}
98+
});
99+
100+
// If not all accesses are 2D, return nullopt
101+
if (!all2D) {
102+
return std::nullopt;
103+
}
104+
}
105+
106+
loopInfos.push_back(info);
107+
}
108+
109+
return loopInfos;
110+
}
111+
};
112+
113+
// Helper function to print loop information
114+
static void printLoopInfo(const SmallVector<LoopInfo> &loopInfos, func::FuncOp funcOp) {
115+
llvm::errs() << "\n=== Band Information ===\n";
116+
117+
// Print loop structure
118+
llvm::errs() << "Loop Structure:\n";
119+
for (size_t i = 0; i < loopInfos.size(); i++) {
120+
const auto &info = loopInfos[i];
121+
llvm::errs() << " Loop " << i << ": [" << info.lowerBound << ", "
122+
<< info.upperBound << ") step " << info.step << "\n";
123+
}
124+
125+
// Print only innermost loop's memory accesses
126+
llvm::errs() << "\nMemory Accesses in Innermost Loop:\n";
127+
const auto &innerLoop = loopInfos.back();
128+
for (const auto &access : innerLoop.accesses) {
129+
llvm::errs() << " " << (access.type == LoopInfo::AccessType::Load ? "Load" : "Store") << " from ";
130+
131+
// Print block argument information
132+
if (auto blockArg = dyn_cast<BlockArgument>(access.memref)) {
133+
llvm::errs() << "<block argument> of type '" << blockArg.getType()
134+
<< "' at index: " << blockArg.getArgNumber()
135+
<< " (arg" << blockArg.getArgNumber() << ")";
136+
} else {
137+
llvm::errs() << access.memref;
138+
}
139+
llvm::errs() << "\n";
140+
llvm::errs() << " Access Map: " << access.accessMap << "\n";
141+
}
142+
llvm::errs() << "================================\n";
143+
}
144+
145+
// Module-level pass registered under the argument "ttl-ops".
//
// For the single function in the module it finds the tileable loop bands,
// validates each band with LoopValidator and prints a summary of every band
// that passes validation. The pass fails with a diagnostic when the module
// does not contain exactly one func.func.
struct TTLOps : public PassWrapper<TTLOps, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TTLOps)

  // Default constructor
  TTLOps() = default;

  // Copy constructor - needed for pass cloning
  TTLOps(const TTLOps &other) : PassWrapper<TTLOps, OperationPass<ModuleOp>>(other) {
    // Copy option values
    localMemorySize = other.localMemorySize;
    loadCost = other.loadCost;
    storeCost = other.storeCost;
  }

  // Pass options. They are declared and copied, but runOnOperation() does not
  // read them.
  Option<unsigned> localMemorySize{
      *this, "local-memory-size",
      llvm::cl::desc("Size of local memory in KB (default: 32)"),
      llvm::cl::init(32)};
  Option<unsigned> loadCost{
      *this, "load-cost",
      llvm::cl::desc("Cost of a load operation (default: 1)"),
      llvm::cl::init(1)};
  Option<unsigned> storeCost{
      *this, "store-cost",
      llvm::cl::desc("Cost of a store operation (default: 1)"),
      llvm::cl::init(1)};

  StringRef getArgument() const override { return "ttl-ops"; }
  StringRef getDescription() const override { return "TTL operations pass"; }

  void runOnOperation() override {
    ModuleOp module = getOperation();

    // The module must contain exactly one function. This depends on the
    // input IR, so report it as a pass failure rather than with an assert,
    // which would abort in debug builds and be ignored in release builds.
    auto funcOps = module.getOps<func::FuncOp>();
    if (std::distance(funcOps.begin(), funcOps.end()) != 1) {
      module.emitError("Expected exactly one function in the module");
      signalPassFailure();
      return;
    }

    // Find perfect loop nests (bands) in each function
    module->walk([&](func::FuncOp funcOp) {
      std::vector<SmallVector<AffineForOp, 6>> bands;
      mlir::affine::getTileableBands(funcOp, &bands);

      // Print information for every band that passes validation; bands that
      // fail validation are skipped silently.
      for (const auto &band : bands) {
        if (auto loopInfos = LoopValidator::validateAndCollectInfo(band)) {
          printLoopInfo(*loopInfos, funcOp);
        }
      }
    });
  }
};
199+
200+
// Register the pass
201+
void registerTTLOps() {
202+
PassRegistration<TTLOps>();
203+
}
204+
} // end anonymous namespace
205+
206+
namespace mlir {
207+
std::unique_ptr<Pass> createTTLOpsPass() {
208+
return std::make_unique<TTLOps>();
209+
}
210+
} // end namespace mlir

0 commit comments

Comments
 (0)