47 changes: 47 additions & 0 deletions Makefile
@@ -0,0 +1,47 @@
# This is not the build system, just a helper to run common development commands.
# Make sure to first initialize the build system with:
# make dev-install

PYTHON := python
BUILD_DIR := $(shell cd python; $(PYTHON) -c 'from build_helpers import get_cmake_dir; print(get_cmake_dir())')
TRITON_OPT := $(BUILD_DIR)/bin/triton-opt

.PHONY: all
all:
ninja -C $(BUILD_DIR)

.PHONY: triton-opt
triton-opt:
ninja -C $(BUILD_DIR) triton-opt

.PHONY: test-lit
test-lit:
ninja -C $(BUILD_DIR) check-triton-lit-tests

.PHONY: test-cpp
test-cpp:
ninja -C $(BUILD_DIR) check-triton-unit-tests

.PHONY: test-python
test-python: all
$(PYTHON) -m pytest python/test/unit

.PHONY: test
test: test-lit test-cpp test-python

.PHONY: dev-install
dev-install:
# build-time dependencies
$(PYTHON) -m pip install ninja cmake wheel pybind11
# test dependencies
$(PYTHON) -m pip install scipy numpy torch pytest lit pandas matplotlib
$(PYTHON) -m pip install -e python --no-build-isolation -v

.PHONY: golden-samples
golden-samples: triton-opt
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
33 changes: 6 additions & 27 deletions README.md
@@ -130,36 +130,15 @@ There currently isn't a turnkey way to run all the Triton tests, but you can
follow the following recipe.

```shell
# One-time setup. Note we have to reinstall local Triton because torch
# One-time setup. Note this will reinstall local Triton because torch
# overwrites it with the public version.
$ pip install scipy numpy torch pytest lit pandas matplotlib && pip install -e python
$ make dev-install

# Run Python tests using your local GPU.
$ python3 -m pytest python/test/unit
# To run all tests (requires a GPU)
$ make test

# Move to builddir. Fill in <...> with the full path, e.g.
# `cmake.linux-x86_64-cpython-3.11`.
$ cd python/build/cmake<...>

# Run C++ unit tests.
$ ctest -j32

# Run lit tests.
$ lit test
```

You may find it helpful to make a symlink to the builddir and tell your local
git to ignore it.

```shell
$ ln -s python/build/cmake<...> build
$ echo build >> .git/info/exclude
```

Then you can e.g. rebuild and run lit with the following command.

```shell
$ ninja -C build && ( cd build ; lit test )
# Or, to run tests without a gpu
$ make test-cpp test-lit
```

# Tips for hacking
17 changes: 17 additions & 0 deletions python/build_helpers.py
@@ -0,0 +1,17 @@
import os
import sysconfig
import sys
from pathlib import Path


def get_base_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))


def get_cmake_dir():
plat_name = sysconfig.get_platform()
python_version = sysconfig.get_python_version()
dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name
cmake_dir.mkdir(parents=True, exist_ok=True)
return cmake_dir
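
For reference, this helper is exactly what the Makefile's `BUILD_DIR` line shells out to. A minimal usage sketch (run from the repo's `python/` directory; the printed path is only an example and depends on platform and Python version):

```python
# Assumes the current working directory is the repo's python/ directory, so
# build_helpers is importable (the Makefile cd's there before invoking Python).
from build_helpers import get_cmake_dir

build_dir = get_cmake_dir()  # also creates the directory if it does not exist
print(build_dir)  # e.g. <repo>/python/build/cmake.linux-x86_64-cpython-3.11
```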
15 changes: 2 additions & 13 deletions python/setup.py
@@ -28,6 +28,8 @@

import pybind11

from build_helpers import get_base_dir, get_cmake_dir


@dataclass
class Backend:
@@ -345,19 +347,6 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
# ---- cmake extension ----


def get_base_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))


def get_cmake_dir():
plat_name = sysconfig.get_platform()
python_version = sysconfig.get_python_version()
dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name
cmake_dir.mkdir(parents=True, exist_ok=True)
return cmake_dir


class CMakeClean(clean):

def initialize_options(self):
24 changes: 0 additions & 24 deletions python/triton/_utils.py
@@ -55,27 +55,3 @@ def _impl(current, path):
else:
ret = dict()
return list(ret.keys())


def parse_list_string(s):
s = s.strip()
if s.startswith('[') and s.endswith(']'):
s = s[1:-1]
result = []
current = ''
depth = 0
for c in s:
if c == '[':
depth += 1
current += c
elif c == ']':
depth -= 1
current += c
elif c == ',' and depth == 0:
result.append(current.strip())
current = ''
else:
current += c
if current.strip():
result.append(current.strip())
return result
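
For context, a short before/after sketch of what this removal means for nested argument types; the signature value below is hypothetical:

```python
# Before this change, nested argument types were flattened into strings and
# re-parsed, e.g. (removed helper):
#   parse_list_string("[i32, [fp32, fp32]]")  # -> ['i32', '[fp32, fp32]']
# After this change, the nesting is carried as plain Python tuples end to end,
# so no string encode/parse round-trip is needed:
nested_signature_value = ("i32", ("fp32", "fp32"))  # hypothetical example
print(nested_signature_value)
```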
2 changes: 1 addition & 1 deletion python/triton/compiler/code_generator.py
@@ -1340,7 +1340,7 @@ def ret(self, node: ast.Call):


def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
arg_types = [str_to_ty(ty) for ty in src.signature.values()]
arg_types = list(map(str_to_ty, src.signature.values()))
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
file_name, begin_line = get_jit_fn_file_line(fn)
# query function representation
12 changes: 4 additions & 8 deletions python/triton/language/__init__.py
@@ -1,7 +1,6 @@
"""isort:skip_file"""
# Import order is significant here.

from .._utils import parse_list_string
from . import math
from . import extra
from .standard import (
@@ -268,8 +267,10 @@


def str_to_ty(name):
if name == "none":
return None
from builtins import tuple

if isinstance(name, tuple):
return tuple_type([str_to_ty(x) for x in name])

if name[0] == "*":
name = name[1:]
@@ -280,11 +281,6 @@ def str_to_ty(name):
ty = str_to_ty(name)
return pointer_type(element_ty=ty, const=const)

if name[0] == "[":
names = parse_list_string(name)
tys = [str_to_ty(x) for x in names]
return tuple_type(types=tys)

if name == "nvTmaDesc":
return nv_tma_desc_type()
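
A hedged sketch of the new tuple handling in `str_to_ty`; the signature value is made up, and the expected result is inferred from the diff above rather than verified output:

```python
# Assumes a Triton checkout that includes this change.
from triton.language import str_to_ty

# A nested tuple from the kernel signature is converted recursively: each leaf
# string goes through the usual conversion, and tuples become tuple_type.
ty = str_to_ty(("*fp32", ("i32", "i32")))
print(ty)  # expected: a tuple_type wrapping a pointer type and a nested tuple_type
```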

13 changes: 3 additions & 10 deletions python/triton/runtime/jit.py
@@ -305,8 +305,8 @@ def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True
return ("nvTmaDesc", None)
elif isinstance(arg, tuple):
spec = [specialize_impl(x, specialize_extra) for x in arg]
tys = [x[0] for x in spec]
keys = [x[1] for x in spec]
tys = tuple([x[0] for x in spec])
keys = tuple([x[1] for x in spec])
return (tys, keys)
else:
# dtypes are hashable so we can memoize this mapping:
@@ -515,13 +515,6 @@ def create_binder(self):
]
return {}, target, backend, binder

def _join_signature(self, obj):
if isinstance(obj, list):
inner = ",".join(self._join_signature(x) for x in obj)
return f"[{inner}]"
else:
return str(obj)

def run(self, *args, grid, warmup, **kwargs):
kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"

@@ -547,7 +540,7 @@ def run(self, *args, grid, warmup, **kwargs):
# signature
sigkeys = [x.name for x in self.params]
sigvals = [x[0] for x in specialization]
signature = {k: self._join_signature(v) for (k, v) in zip(sigkeys, sigvals)}
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
# check arguments
assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
assert "device" not in kwargs, "device option is deprecated; current device will be used"
5 changes: 1 addition & 4 deletions test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
@@ -11,10 +11,7 @@
// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
// CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
// To regenerate this test case, run `make golden-samples` in the triton root directory
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s
// CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5 changes: 1 addition & 4 deletions test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in
@@ -1,7 +1,4 @@
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
// To regenerate this test case, run `make golden-samples` in the triton root directory
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5 changes: 1 addition & 4 deletions test/TritonGPU/samples/simulated-grouped-gemm.mlir
@@ -10,10 +10,7 @@
// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
// CHECK: #[[$ATTR_4:.+]] = #ttg.shared_memory
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir
// To regenerate this test case, run `make golden-samples` in the triton root directory
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s
// CHECK-LABEL: tt.func public @matmul_kernel_descriptor_persistent(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5 changes: 1 addition & 4 deletions test/TritonGPU/samples/simulated-grouped-gemm.mlir.in
@@ -1,7 +1,4 @@
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir
// To regenerate this test case, run `make golden-samples` in the triton root directory
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
39 changes: 19 additions & 20 deletions third_party/amd/backend/driver.py
@@ -8,7 +8,6 @@
from triton.runtime.cache import get_cache_manager
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver
from triton._utils import parse_list_string

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir = [os.path.join(dirname, "include")]
@@ -188,30 +187,30 @@ def ty_to_cpp(ty):

def make_launcher(constants, signature, warp_size):

def _serialize_signature(sig):
if isinstance(sig, tuple):
return ','.join(map(_serialize_signature, sig))
return sig

def _extracted_type(ty):
if ty == "constexpr":
return "PyObject*"
if isinstance(ty, tuple):
val = ','.join(map(_extracted_type, ty))
return f"[{val}]"
if ty[0] == '*':
return "PyObject*"
if ty[0] == '[':
if ty == "[]":
return "[]"
tys = parse_list_string(ty)
val = ','.join(map(_extracted_type, tys))
return f"[{val}]"
if ty in ("constexpr"):
return "PyObject*"
return ty_to_cpp(ty)

def format_of(ty):
if ty == "hipDeviceptr_t":
return "O"
if ty[0] == "[":
if ty == "[]":
return "()"
tys = parse_list_string(ty)
val = ''.join(map(format_of, tys))
if isinstance(ty, tuple):
val = ''.join(map(format_of, ty))
return f"({val})"
if ty[0] == '*':
return "O"
if ty in ("constexpr"):
return "O"
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
@@ -223,11 +222,11 @@ def format_of(ty):
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
}[ty]
}[ty_to_cpp(ty)]

args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
args_format = ''.join([format_of(ty) for ty in signature.values()])
format = "iiiKKOOOO" + args_format
signature = ','.join(signature.values()).replace('[', '').replace(']', '')
signature = ','.join(map(_serialize_signature, signature.values()))
signature = list(filter(bool, signature.split(',')))
signature = {i: s for i, s in enumerate(signature)}
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
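
A standalone sketch of the flattening step added above; `_serialize_signature` is copied from the diff and the signature dict is hypothetical, so the block runs on its own:

```python
def _serialize_signature(sig):
    # Flatten nested tuples into a comma-separated string, as the launcher does.
    if isinstance(sig, tuple):
        return ','.join(map(_serialize_signature, sig))
    return sig

# Hypothetical signature: a pointer, a tuple argument, and a constexpr.
signature = {0: "*fp32", 1: ("i32", "i32"), 2: "constexpr"}
flat = ','.join(map(_serialize_signature, signature.values()))
print(flat)  # -> *fp32,i32,i32,constexpr
```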