Skip to content

Commit aebf448

Browse files
authored
Merge branch 'main' into liyang/fix_dot3d_on_lnl
2 parents 8b1f41c + 92dd27c commit aebf448

File tree

45 files changed

+1394
-476
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1394
-476
lines changed

.github/pins/ipex.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
cd132db4e11fbf799a2d3ed2afea100a4afd4efd
1+
15ef7db18b0a50101b41d9c78780d35ea7937ffc
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
name: Build on Windows
2+
3+
on:
4+
workflow_dispatch:
5+
6+
pull_request:
7+
branches:
8+
- main
9+
- release/**
10+
push:
11+
branches:
12+
- main
13+
- release/**
14+
15+
permissions: read-all
16+
17+
env:
18+
NEW_WORKSPACE: C:\gh${{ github.run_id }}
19+
20+
jobs:
21+
build:
22+
name: Build
23+
runs-on: avc336
24+
steps:
25+
- name: Enable long paths
26+
run: |
27+
git config --system core.longPaths true
28+
29+
- name: Checkout repository
30+
uses: actions/checkout@v4
31+
32+
- name: Install Python
33+
uses: actions/setup-python@v5
34+
with:
35+
python-version: '3.9'
36+
37+
# Copy workspace to a temporary location with a shorter name.
38+
- name: Copy workspace
39+
run: |
40+
Copy-Item -Path ${{ github.workspace }} -Destination ${{ env.NEW_WORKSPACE }} -Recurse
41+
42+
# We need ninja >= 1.12.0 to support long names on Windows. At the moment there is no required
43+
# version in pypi, so instead of installing ninja with pip we use a preinstalled 1.12.1 on the
44+
# runner.
45+
- name: Build Triton
46+
run: |
47+
cd ${{ env.NEW_WORKSPACE }}
48+
cd python
49+
pip install -U wheel pybind11 certifi cython cmake
50+
python -m certifi
51+
pip install --no-build-isolation '.[build]'
52+
53+
- name: Clean
54+
if: ${{ always() }}
55+
run: |
56+
Remove-Item -LiteralPath ${{ env.NEW_WORKSPACE }} -Force -Recurse -ErrorAction Ignore

.github/workflows/build-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ on:
4343
pull_request:
4444
branches:
4545
- main
46+
- release/**
4647
push:
4748
branches:
4849
- main
50+
- release/**
4951

5052
permissions: read-all
5153

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ jobs:
404404
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
405405
fi
406406
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
407+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
407408
cd python/test/unit
408409
pytest --capture=tee-sys -rfs -n 16 language runtime \
409410
--ignore=language/test_line_info.py \

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ jobs:
402402
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
403403
fi
404404
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
405+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
405406
cd python/test/unit
406407
pytest --capture=tee-sys -rfs -n 16 language runtime \
407408
--ignore=language/test_line_info.py \

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11421142
bool isHopper() const;
11431143

11441144
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
1145-
int bitwidth, int opIdx) const;
1145+
int bitwidth, int kWidth,
1146+
int opIdx) const;
11461147
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
11471148

11481149
bool supportReduction() const {

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
2828
"TRITON_DISABLE_LINE_INFO",
2929
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
3030
"TRITON_ENABLE_LLVM_DEBUG",
31+
"TRITON_HIP_STREAM_PREFETCH",
3132
"TRITON_LLVM_DEBUG_ONLY",
3233
"USE_IR_LOC",
3334
"NVPTX_ENABLE_DUMP",

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
953953
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
954954
if (mma.isAmpere() || mma.isHopper()) {
955955
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
956-
auto rep = mma.getRepForOperand(shape, bitwidth, idx);
956+
auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
957957
auto sizePerThread = getSizePerThread();
958958
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
959959
if (rank == 3)
@@ -2018,14 +2018,18 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
20182018

20192019
SmallVector<int64_t>
20202020
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
2021-
int opIdx) const {
2021+
int kWidth, int opIdx) const {
20222022
auto rank = shape.size();
20232023
auto warpsPerCTA = getWarpsPerCTA();
20242024

20252025
// {batch, m, n, k}
20262026
// Hopper path never uses the n value, since this method is only invoked
20272027
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
2028-
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
2028+
// TODO: rep per operand is not accurate for Hopper. It is currently done that
2029+
// way to allow us to get the correct total number of elements. this will be
2030+
// fixed when moving to linear layout.
2031+
SmallVector<int> shapePerWarp = {
2032+
1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
20292033
int numRepBatch =
20302034
rank == 3
20312035
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,7 @@ void setUseAccFlag(Operation *op, Value useAcc) {
3838
}
3939

4040
bool isConstantZeroTensor(Value v) {
41-
auto constOp = v.getDefiningOp<arith::ConstantOp>();
42-
if (!constOp)
43-
return false;
44-
auto splat = mlir::dyn_cast<SplatElementsAttr>(constOp.getValue());
45-
if (!splat)
46-
return false;
47-
return splat.getSplatValue<FloatAttr>().getValue().convertToFloat() == 0.0f;
41+
return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat()));
4842
}
4943

5044
std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,

python/setup.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ def find_visual_studio(version_ranges):
119119
for version_range in version_ranges:
120120
command = [
121121
str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
122-
"-property", "installationPath", "-prerelease"
122+
"-products", "*", "-property", "installationPath", "-prerelease"
123123
]
124124

125125
try:
126126
output = subprocess.check_output(command, text=True).strip()
127127
if output:
128-
return output
128+
return output.split("\n")[0]
129129
except subprocess.CalledProcessError:
130130
continue
131131

@@ -146,6 +146,13 @@ def set_env_vars(vs_path, arch="x64"):
146146
os.environ[var] = value
147147

148148

149+
def initialize_visual_studio_env(version_ranges, arch="x64"):
150+
vs_path = find_visual_studio(version_ranges)
151+
if not vs_path:
152+
raise EnvironmentError("Visual Studio not found in specified version ranges.")
153+
set_env_vars(vs_path, arch)
154+
155+
149156
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
150157
def check_env_flag(name: str, default: str = "") -> bool:
151158
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
@@ -447,10 +454,7 @@ def build_extension(self, ext):
447454
lit_dir = shutil.which('lit')
448455
ninja_dir = shutil.which('ninja')
449456
if platform.system() == "Windows":
450-
vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"])
451-
env = set_env_vars(vs_path)
452-
if not vs_path:
453-
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
457+
initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"])
454458
# lit is used by the test suite
455459
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
456460
thirdparty_cmake_args += self.get_pybind11_cmake_args()

0 commit comments

Comments
 (0)