Skip to content
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
set(CMAKE_CXX_COMPILER clang++)
set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
elseif(FLAGTREE_BACKEND STREQUAL "aipu")
set(CMAKE_C_COMPILER clang-16)
set(CMAKE_CXX_COMPILER clang++-16)
set(CMAKE_C_COMPILER clang-15)
set(CMAKE_CXX_COMPILER clang++-15)
add_definitions(-D__NVIDIA__)
add_definitions(-D__AMD__)
elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro")
Expand Down
4 changes: 2 additions & 2 deletions python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ def check_env(env_val):

# aipu
cache.store(
file="llvm-a66376b0-ubuntu-x64",
file="llvm-a66376b0-ubuntu-arm64",
condition=("aipu" == flagtree_backend),
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-x64.tar.gz",
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-arm64.tar.gz",
pre_hock=lambda: check_env('LLVM_SYSPATH'),
post_hock=set_llvm_env,
)
Expand Down
250 changes: 225 additions & 25 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,218 @@
import os
import re
import hashlib
import subprocess
import sysconfig
from abc import ABCMeta, abstractmethod

from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
from typing import Dict, Union
from typing import Dict, List, Tuple, Union
from types import ModuleType

# Table that associates strings to AttrsDescriptor (sub)classes.
# In this way we can dynamically select the correct class
# constructor (see `AttrsDescriptor.from_dict`).
_descriptor_table = {}


def register_descriptor(cls):
    """
    Register a descriptor class into the descriptor table.

    The class is stored under its `__name__` and returned unchanged, so the
    function can be used as a class decorator.
    """
    _descriptor_table[cls.__name__] = cls
    return cls


@register_descriptor
class AttrsDescriptor:
    """
    This class handles compile-time properties for specific function parameters.

    Different backends can add more properties to the common ones. The class
    contains three fields:

    `arg_properties`: a dictionary containing the different compile-time properties for different
        parameters. I.e., the dictionary is a map from property names to parameter indices
        {
        "prop0": (0, 2, 3)
        "prop1": (0, 4, 5)
        }
        Different backends might need different properties on those parameters to enable
        specific optimizations. The common compile time properties contained in this class
        are :
        - "tt.divisibility", i.e., is the given parameter divisible by 16
        - "tt.equal_to", i.e., is the given parameter an integer constant 1

    `property_values`: a dictionary containing the value of the different compile-time properties, like:
        {
        "prop0": val0
        "prop1": val1
        }

    `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant

    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties

        We can initialize the AttrsDescriptor class by passing the list of params
        of the function and their `values`. The function will try to apply the properties
        to the values and save the parameters in the `arg_properties` dictionary. If we don't pass
        either the `params` or the `values` we should initialize the class via an alternative method
        (see `from_dict` or `from_hints`)
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """ Add common compile-time properties """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        self.constant_properties.add("tt.equal_to")

        # Without both params and values there is nothing to deduce; the
        # caller is expected to fill `arg_properties` afterwards.
        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        # Divisibility property
        self.arg_properties["tt.divisibility"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_divisible_by_16(arg) and not param.do_not_specialize
            and not param.do_not_specialize_on_alignment
        ]

        # Equal to 1 property
        self.arg_properties["tt.equal_to"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
        ]

    def _add_backend_properties(self, params=None, values=None):
        """ This method is for different subclasses to implement their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Initialize the slots of this class

        Every entry of `arg_properties` is mirrored into an attribute named after the
        property (without the `tt.` prefix) and its value, e.g. `tt.divisibility` with
        value 16 becomes `divisibility_16`.
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> Dict:
        """
        Get the function attributes as a dictionary.

        The returned dictionary will look like :
            {
            "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
            "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
            }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> Dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """ Return a deep copy of this object, without properties marked as constants"""
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the field a set (as in `__init__`), not a dict.
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a sha256 hex digest identifying the content of this descriptor

        The digest covers the property names as well as the parameter indices, so that
        two descriptors assigning the same indices to different properties do not collide.
        Parameter indices are sorted, so their order and container type do not matter.
        """
        values = [sorted((name, sorted(ids)) for name, ids in self.arg_properties.items())]
        values += [sorted(self.property_values.items(), key=lambda item: item[0])]
        values += [sorted(self.constant_properties)]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """
        Store the fields of this class in a serializable dictionary
        """
        # We need to only store the `arg_properties` field. To initialize the
        # other fields we rely on the class type. We store it as a string in
        # the dictionary so that we can use it to invoke the appropriate
        # (sub)class constructor in the `from_dict` method.
        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

    @staticmethod
    def from_dict(data):
        """
        Create the object from a serializable dictionary

        `data["cls"]` must be the name of a class registered with `register_descriptor`;
        a `KeyError` is raised otherwise.
        """
        attrs_descriptor = _descriptor_table[data["cls"]]()
        for prop_name, param_ids in data["arg_properties"].items():
            attrs_descriptor.arg_properties[prop_name] = param_ids
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @classmethod
    def from_hints(cls, hints: Union[Dict[int, int], List[Tuple[int, int]]]):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and values,
        the user can pass in hints either as a mapping `{param_index: val}` or as a
        list `[(param_index, val)]`. If `val` matches one of the values of the
        properties (e.g., `prop_val[prop0]`), then we insert `param_index` into the
        correct list (e.g., in `arg_properties[prop0]`)
        """
        # Normalise both accepted shapes to a reusable list of (index, value) pairs.
        pairs = list(hints.items()) if hasattr(hints, "items") else list(hints)
        attrs_descriptor = cls()
        for prop_name, prop_val in attrs_descriptor.property_values.items():
            attrs_descriptor.arg_properties[prop_name] = [i for i, h in pairs if h == prop_val]
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @staticmethod
    def is_divisible_by_16(x):
        """ Return if the argument is a multiple of 16"""
        # Objects exposing `data_ptr()` are judged by the returned address.
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        if isinstance(x, int):
            return x % 16 == 0
        # `None` is treated as divisible; every other type is not.
        return x is None

    @staticmethod
    def is_equal_to_1(x):
        """ Return if the argument is a constant 1"""
        # `bool` is a subclass of `int`, so `True` has to be excluded explicitly.
        return isinstance(x, int) and not isinstance(x, bool) and x == 1

    @staticmethod
    def get_property_key(val, align):
        """
        Return the one-letter key of the property matching `val`

        "D" if `align` is set and `val` is divisible by 16, "1" if `val` is the
        integer constant 1, "N" otherwise.
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"

    def __repr__(self):
        return f"AttrsDescriptor.from_dict({self.to_dict()!r})"


@dataclass(frozen=True)
class GPUTarget(object):
Expand All @@ -25,23 +231,22 @@ def __init__(self, target: GPUTarget) -> None:

# Locate the executable `binary` and report its version.
# Candidates are the `TRITON_<BINARY>_PATH` environment variable and
# `third_party/cuda/bin/<binary>` under the parent of this file's directory.
# A candidate that exists as a regular file is run with `--version`; when its
# output contains "release X.Y" the candidate and "X.Y" are returned as a pair.
# Raises RuntimeError when no candidate yields a version.
# NOTE(review): two variants of the search loop (`for path in paths` /
# `for p in paths`) and of the return statement appear back to back below,
# which looks like two revisions of this method interleaved - confirm which
# one is intended. The `for p` variant only probes the part of the entry
# before the first space but returns the whole entry.
@staticmethod
def _path_to_binary(binary: str):
binary += sysconfig.get_config_var("EXE")
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
]
for path in paths:
if os.path.exists(path) and os.path.isfile(path):
result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
for p in paths:
bin = p.split(" ")[0]
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return path, version.group(1)
return p, version.group(1)
raise RuntimeError(f"Cannot find {binary}")

# Abstract hook taking a `GPUTarget`; this base version always raises
# NotImplementedError.
# NOTE(review): both `@abstractmethod` and `@abstractclassmethod` are stacked
# under `@classmethod` here, which looks like two revisions interleaved - confirm
# which decorator set is intended.
# NOTE(review): the signature has no `cls` parameter although the method is
# declared as a class method, so `target` would be bound to the class - confirm.
@classmethod
@abstractmethod
@abstractclassmethod
def supports_target(target: GPUTarget):
raise NotImplementedError

Expand Down Expand Up @@ -84,21 +289,16 @@ def get_module_map(self) -> Dict[str, ModuleType]:
"""
raise NotImplementedError

@staticmethod
def parse_attr(desc):
    """
    Translate a specialization string into a list of `[name, value]` attribute pairs.

    Only the letter "D" is recognised; it maps to `["tt.divisibility", 16]`.
    Any string without a "D" yields an empty list.
    """
    assert isinstance(desc, str)
    return [["tt.divisibility", 16]] if "D" in desc else []
def get_attrs_descriptor(self, params, args):
    """
    Build the attribute descriptor for the given parameters and arguments.

    The descriptor stores compile time properties of the arguments that can
    improve code generation. This implementation returns a plain
    `AttrsDescriptor`; different backends might benefit from different properties.
    """
    descriptor = AttrsDescriptor(params, args)
    return descriptor

# NOTE(review): two method headers (`get_arg_specialization` and
# `compute_spec_key`), two docstring lines and two bodies appear interleaved
# below, which looks like two revisions of this spot in the file - confirm which
# one is intended.
# `get_arg_specialization` lines: return "D" when `align` is passed and truthy and
# either an "int" argument or a "tensor" argument's `data_ptr()` is a multiple of
# 16, and "" otherwise.
# `compute_spec_key` lines: delegate to `AttrsDescriptor.get_property_key(arg, align)`.
@staticmethod
def get_arg_specialization(arg, ty, **kwargs):
def compute_spec_key(self, arg, align):
"""
Return a string unique to each possible specialization of the argument
Return the ascii key for a given argument with a given set of properties
"""
if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
return "D"
if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
return "D"
return ""
return AttrsDescriptor.get_property_key(arg, align)
8 changes: 4 additions & 4 deletions third_party/aipu/backend/codegen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import tvm
from tvm import tir, ir
from tvm.script.parser import tir as T
from tvm.compass.dsl import BuildManager, script as S
#import tvm
#from tvm import tir, ir
#from tvm.script.parser import tir as T
#from tvm.compass.dsl import BuildManager, script as S
from mlir import ir as mlir_ir
from mlir.dialects import func

Expand Down
Loading