Skip to content

Commit 8c00e2d

Browse files
committed
use ptr type in tptr, add opconverter and typeconverter for memref<32x!ptr<default.....>>
1 parent 454973f commit 8c00e2d

File tree

5 files changed

+634
-211
lines changed

5 files changed

+634
-211
lines changed

backend/compiler.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _get_sanitizer_type():
4242
if sanitizer_type != "" and sanitizer_type != "asan" and sanitizer_type != "tsan":
4343
# throw error
4444
raise Exception(f"TRITON_SHARED_SANITIZER_TYPE {sanitizer_type} is invalid.")
45-
45+
4646
return sanitizer_type
4747

4848
def _ttir_to_ttsharedir(mod):
@@ -92,6 +92,7 @@ def _ttsharedir_to_llir(ttsharedir: str):
9292
triton_shared.to_llir.add_convert_linalg_to_loops(pm)
9393
triton_shared.to_llir.add_expand_strided_metadata(pm)
9494
triton_shared.to_llir.add_convert_scf_to_cf(pm)
95+
triton_shared.to_llir.add_convert_tptr_to_llvm(pm)
9596
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
9697
triton_shared.to_llir.add_convert_math_to_llvm(pm)
9798
triton_shared.to_llir.add_convert_complex_to_llvm(pm)
@@ -100,9 +101,6 @@ def _ttsharedir_to_llir(ttsharedir: str):
100101
triton_shared.to_llir.add_memref_expand(pm)
101102
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
102103
triton_shared.to_llir.add_convert_func_to_llvm(pm)
103-
# triton_shared.debug.enable_mlir_debug("tptr-to-llvm")
104-
triton_shared.to_llir.add_convert_tptr_to_llvm(pm)
105-
106104
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
107105
triton_shared.to_llir.add_lower_affine(pm)
108106
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
@@ -142,16 +140,16 @@ def _llir_to_bin(llir: str, metadata):
142140
# using a sanitizer
143141
# invoke pass to append sanitizer attributes
144142
instrumented_src_path = os.path.join(tmpdir, "kernel-instrumented.ll")
145-
143+
146144
opt_path = _get_llvm_bin_path("opt")
147145
top_level_triton_path = os.path.dirname(triton.__file__)
148146
sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None))
149147

150148
if not sanitizer_attributes_pass_path:
151149
raise Exception(f"libSanitizerAttributes.so does not exist.")
152150

153-
subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
154-
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
151+
subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
152+
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
155153
"-o", instrumented_src_path])
156154

157155
# compile to object file
@@ -163,12 +161,12 @@ def _llir_to_bin(llir: str, metadata):
163161
subprocess_args.extend(["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"])
164162
elif sanitizer_type == "tsan":
165163
subprocess_args.extend(["-g", "-fsanitize=thread"])
166-
164+
167165
subprocess.check_call(subprocess_args)
168166
else:
169167
llc_path = _get_llvm_bin_path("llc")
170168
subprocess.check_call([llc_path, src_path, "-filetype=obj", "-relocation-model=pic", "-o", dst_path])
171-
169+
172170
return Path(dst_path).read_bytes()
173171

174172

@@ -262,11 +260,11 @@ def add_stages(self, stages, options, language):
262260
stages["llir"] = lambda src, metadata: _optimize_llir(_ttsharedir_to_llir(src))
263261
stages["obj"] = lambda src, metadata: _llir_to_bin(src, metadata)
264262

265-
266263
@functools.lru_cache()
267264
def hash(self):
268265
return self.target
269266

270267
# The CPU backend does not use any extra python modules, return an empty dictionary
271268
def get_module_map(self) -> Dict[str, ModuleType]:
272269
return {}
270+

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
116116
attr-dict $baseType custom<IntType>(type($result))
117117
}];
118118
let hasFolder = 1;
119+
let extraClassDeclaration = [{
120+
/// Returns the type offset according to `layout`. If `layout` is `nullopt`
121+
/// the nearest layout the op will be used for the computation.
122+
llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
123+
}];
119124
}
120125

121126
def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {
@@ -215,3 +220,4 @@ def TTPTR_StoreOp : TPTR_Op<"store", [
215220
}
216221

217222
#endif // TPTR_DIALECT
223+

0 commit comments

Comments
 (0)