
Commit 9a32209

Setup clang-format action and format the whole codebase. (#349)
This PR does the following:

* Clang-formats all C and C++ files within our repository using the LLVM style.
* Creates a GitHub workflow that verifies code formatting on PRs.
* Removes the symlink of clang-format in favor of a permanent copy. We do not need to keep it up to date with Triton: the likelihood of Triton updating their clang-format file is low enough for this not to be a problem, and even if it were updated, fully matching their style is not essential. It is more important to have a permanent copy that keeps the workflow simple.
1 parent b673908 commit 9a32209
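
For context, most of the diff below is mechanical: with BasedOnStyle: LLVM, clang-format enforces an 80-column limit and keeps short template declarations and bodies on a single line when they fit. A minimal illustrative C++ sketch of that rule (the long struct name is hypothetical and not from the repository; stdSort is the helper reformatted in CRunnerUtils.cpp below):

// Illustrative only -- how clang-format's LLVM style lays out templates.
#include <algorithm>
#include <cstdint>

namespace example {

// Fits within 80 columns, so the declaration and body stay on one line.
template <typename V> void stdSort(uint64_t n, V *p) { std::sort(p, p + n); }

// Exceeds 80 columns as a one-liner, so clang-format breaks after the
// template parameter list instead.
template <typename ElementType, int FirstDim, int SecondDim, int ThirdDim>
struct HypotheticalPaddedStorageBuffer;

} // namespace example

Most of the header hunks below are this same transformation, applied by the clang-format run described above.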

25 files changed: +175 additions, -172 deletions


.clang-format

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+BasedOnStyle: LLVM

.github/workflows/clang-format.yml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+name: clang-format
+
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  clang-format:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.12
+        uses: actions/setup-python@v2
+        with:
+          python-version: "3.12"
+      - name: Install dependencies
+        run: |
+          pip install clang-format==20.1.8 ripgrep==14.1.0
+      - name: Running clang-format
+        run: |
+          rg . --type cpp --type c --files-with-matches \
+            | xargs clang-format --dry-run --Werror

.gitignore

Lines changed: 2 additions & 3 deletions
@@ -3,6 +3,5 @@
 compile_commands.json
 build/*
 .vscode/*
-/.clang-format
-test_core.py
-test_annotations.py
+/python/examples/test_core.py
+/python/examples/test_annotations.py

CMakeLists.txt

Lines changed: 2 additions & 5 deletions
@@ -27,15 +27,12 @@ if (TRITON_SHARED_BUILD_CPU_BACKEND)
 target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers)
 endif()

-# Add symlinks to selected pytest files and the clang-format setting in triton. The tests are imported into triton-shared’s test folder to
-# run under triton-shared's conftest configuration, and the clang-format link ensures consistent code style enforcement across both repositories.
+# Add symlinks to selected pytest files in triton. The tests are imported into triton-shared’s test folder to
+# run under triton-shared's conftest configuration.
 cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_core.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_CORE)
 cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_annotations.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_ANNOTATIONS)
-cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_SHARED_CLANG_FORMAT_SETTING)
 cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_core.py" OUTPUT_VARIABLE TRITON_TEST_CORE)
 cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_annotations.py" OUTPUT_VARIABLE TRITON_TEST_ANNOTATIONS)
-cmake_path(APPEND CMAKE_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_CLANG_FORMAT_SETTING)

 add_symlink(${TRITON_TEST_CORE} ${TRITON_SHARED_TEST_CORE})
 add_symlink(${TRITON_TEST_ANNOTATIONS} ${TRITON_SHARED_TEST_ANNOTATIONS})
-add_symlink(${TRITON_CLANG_FORMAT_SETTING} ${TRITON_SHARED_CLANG_FORMAT_SETTING})

backend/include/ExecutionEngine/CRunnerUtils.cpp

Lines changed: 2 additions & 5 deletions
@@ -16,7 +16,7 @@
 #include "Msan.h"

 #ifndef _WIN32
-#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
+#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
     defined(__DragonFly__)
 #include <cstdlib>
 #else
@@ -37,10 +37,7 @@
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

 namespace {
-template <typename V>
-void stdSort(uint64_t n, V *p) {
-  std::sort(p, p + n);
-}
+template <typename V> void stdSort(uint64_t n, V *p) { std::sort(p, p + n); }

 } // namespace

backend/include/ExecutionEngine/CRunnerUtils.h

Lines changed: 17 additions & 34 deletions
@@ -50,11 +50,9 @@ constexpr unsigned nextPowerOf2(int n) {
   return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2)));
 }

-template <typename T, int Dim, bool IsPowerOf2>
-struct Vector1D;
+template <typename T, int Dim, bool IsPowerOf2> struct Vector1D;

-template <typename T, int Dim>
-struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
+template <typename T, int Dim> struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
   Vector1D() {
     static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
                   "size error");
@@ -68,8 +66,7 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {

 // 1-D vector, padded to the next power of 2 allocation.
 // Specialization occurs to avoid zero size arrays (which fail in -Werror).
-template <typename T, int Dim>
-struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
+template <typename T, int Dim> struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
   Vector1D() {
     static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
     static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
@@ -86,8 +83,7 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
 } // namespace mlir

 // N-D vectors recurse down to 1-D.
-template <typename T, int Dim, int... Dims>
-struct Vector {
+template <typename T, int Dim, int... Dims> struct Vector {
   inline Vector<T, Dims...> &operator[](unsigned i) { return vector[i]; }
   inline const Vector<T, Dims...> &operator[](unsigned i) const {
     return vector[i];
@@ -105,30 +101,25 @@ struct Vector<T, Dim>
                                     mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
 };

-template <int D1, typename T>
-using Vector1D = Vector<T, D1>;
-template <int D1, int D2, typename T>
-using Vector2D = Vector<T, D1, D2>;
+template <int D1, typename T> using Vector1D = Vector<T, D1>;
+template <int D1, int D2, typename T> using Vector2D = Vector<T, D1, D2>;
 template <int D1, int D2, int D3, typename T>
 using Vector3D = Vector<T, D1, D2, D3>;
 template <int D1, int D2, int D3, int D4, typename T>
 using Vector4D = Vector<T, D1, D2, D3, D4>;

-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
+template <int N> void dropFront(int64_t arr[N], int64_t *res) {
   for (unsigned i = 1; i < N; ++i)
     *(res + i - 1) = arr[i];
 }

 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for StridedMemRef type.
 //===----------------------------------------------------------------------===//
-template <typename T, int Rank>
-class StridedMemrefIterator;
+template <typename T, int Rank> class StridedMemrefIterator;

 /// StridedMemRef descriptor type with static rank.
-template <typename T, int N>
-struct StridedMemRefType {
+template <typename T, int N> struct StridedMemRefType {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -165,8 +156,7 @@ struct StridedMemRefType {
 };

 /// StridedMemRef descriptor type specialized for rank 1.
-template <typename T>
-struct StridedMemRefType<T, 1> {
+template <typename T> struct StridedMemRefType<T, 1> {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -188,8 +178,7 @@ struct StridedMemRefType<T, 1> {
 };

 /// StridedMemRef descriptor type specialized for rank 0.
-template <typename T>
-struct StridedMemRefType<T, 0> {
+template <typename T> struct StridedMemRefType<T, 0> {
   T *basePtr;
   T *data;
   int64_t offset;
@@ -207,8 +196,7 @@ struct StridedMemRefType<T, 0> {
 };

 /// Iterate over all elements in a strided memref.
-template <typename T, int Rank>
-class StridedMemrefIterator {
+template <typename T, int Rank> class StridedMemrefIterator {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;
@@ -261,8 +249,7 @@ class StridedMemrefIterator {
 };

 /// Iterate over all elements in a 0-ranked strided memref.
-template <typename T>
-class StridedMemrefIterator<T, 0> {
+template <typename T> class StridedMemrefIterator<T, 0> {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;
@@ -307,21 +294,18 @@ class StridedMemrefIterator<T, 0> {
 // Codegen-compatible structure for UnrankedMemRef type.
 //===----------------------------------------------------------------------===//
 // Unranked MemRef
-template <typename T>
-struct UnrankedMemRefType {
+template <typename T> struct UnrankedMemRefType {
   int64_t rank;
   void *descriptor;
 };

 //===----------------------------------------------------------------------===//
 // DynamicMemRefType type.
 //===----------------------------------------------------------------------===//
-template <typename T>
-class DynamicMemRefIterator;
+template <typename T> class DynamicMemRefIterator;

 // A reference to one of the StridedMemRef types.
-template <typename T>
-class DynamicMemRefType {
+template <typename T> class DynamicMemRefType {
 public:
   int64_t rank;
   T *basePtr;
@@ -388,8 +372,7 @@ class DynamicMemRefType {
 };

 /// Iterate over all elements in a dynamic memref.
-template <typename T>
-class DynamicMemRefIterator {
+template <typename T> class DynamicMemRefIterator {
 public:
   using iterator_category = std::forward_iterator_tag;
   using value_type = T;

include/triton-shared/Analysis/MaskAnalysis.h

Lines changed: 3 additions & 2 deletions
@@ -90,8 +90,9 @@ struct MaskState {
   LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState,
                           Location loc, OpBuilder &builder);

-  LogicalResult minStateScalar(const MaskState &lhsState, const MaskState &rhsState,
-                               Location loc, OpBuilder &builder);
+  LogicalResult minStateScalar(const MaskState &lhsState,
+                               const MaskState &rhsState, Location loc,
+                               OpBuilder &builder);

   LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState,
                           Location loc, OpBuilder &builder);

include/triton-shared/AnalysisStructured/PtrAnalysis.h

Lines changed: 9 additions & 7 deletions
@@ -46,9 +46,9 @@ const extern std::string ptrAnalysisAttr;
 // address, it will be collapsed to 1D. To support gather/scatter access, treat
 // the unstructured offset as a whole offset instead of decoding the pointer
 // arithmetic on it except scalar mul.
-// The stride is set to 1 when there's no scalar mul so it still matches the offset *
-// stride formula. When there're scalar muls, the stride is set to the multiplication
-// of all the scalar strides.
+// The stride is set to 1 when there's no scalar mul so it still matches the
+// offset * stride formula. When there're scalar muls, the stride is set to the
+// multiplication of all the scalar strides.
 struct PtrState {
   SmallVector<OpFoldResult> offsets;
   SmallVector<OpFoldResult> sizes;
@@ -321,14 +321,16 @@ class PtrAnalysis {
   // Operand is the result of tt.int_to_ptr.
   // Expected result:
   // Directly grab op result
-  LogicalResult visitOperandIntToPtr(triton::IntToPtrOp intToPtrOp, PtrState &state,
-                                     const Location loc, OpBuilder &builder);
+  LogicalResult visitOperandIntToPtr(triton::IntToPtrOp intToPtrOp,
+                                     PtrState &state, const Location loc,
+                                     OpBuilder &builder);

   // Operand is the result of tt.bitcast.
   // Expected result:
   // Directly grab op result
-  LogicalResult visitOperandBitcast(triton::BitcastOp bitcastOp, PtrState &state,
-                                    const Location loc, OpBuilder &builder);
+  LogicalResult visitOperandBitcast(triton::BitcastOp bitcastOp,
+                                    PtrState &state, const Location loc,
+                                    OpBuilder &builder);

   // Get the computed PtrState for the forOp's init-arg at the provided index.
   FailureOr<PtrState> getLoopInitArgPtrState(scf::ForOp forOp, size_t index);

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 22 additions & 22 deletions
@@ -813,40 +813,40 @@ struct AssertConverter : public OpConversionPattern<triton::AssertOp> {
     Value condVal = op.getCondition();

     auto assertMessage =
-        llvm::formatv("Assertion `{0}` failed", op.getMessage());
-
-    // The condition can only be I1 or I1Tensor (integer or tensor) from TritonOps.td.
-    // Tensors will always be RankedTensorType.
+        llvm::formatv("Assertion `{0}` failed", op.getMessage());
+
+    // The condition can only be I1 or I1Tensor (integer or tensor) from
+    // TritonOps.td. Tensors will always be RankedTensorType.
     if (isa<mlir::IntegerType>(condVal.getType())) {
       // handle scalar case
       rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
                                           assertMessage.str());
-    } else if (auto tensorType = dyn_cast<RankedTensorType>(condVal.getType())) {
+    } else if (auto tensorType =
+                   dyn_cast<RankedTensorType>(condVal.getType())) {
       // handle tensor case
       int64_t rank = tensorType.getRank();

       // create identity mapping for access pattern
-      SmallVector<AffineMap, 3> indexingMaps{AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext())};
+      SmallVector<AffineMap, 3> indexingMaps{
+          AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext())};

       // loops do not depend on each other
-      SmallVector<utils::IteratorType, 3> iteratorTypes(rank, utils::IteratorType::parallel);
+      SmallVector<utils::IteratorType, 3> iteratorTypes(
+          rank, utils::IteratorType::parallel);

       rewriter.create<linalg::GenericOp>(
-          op.getLoc(),
-          TypeRange{},
-          condVal,
-          ValueRange{},
-          ArrayRef<AffineMap>{indexingMaps},
-          ArrayRef<utils::IteratorType>{iteratorTypes},
-          [&](OpBuilder &b, Location loc, ValueRange args) {
-            // obtain the element in the tensor
-            Value element = args[0];
-
-            // make a cf.assert for the current element
-            b.create<mlir::cf::AssertOp>(loc, element, assertMessage.str());
-
-            b.create<linalg::YieldOp>(loc);
-          });
+          op.getLoc(), TypeRange{}, condVal, ValueRange{},
+          ArrayRef<AffineMap>{indexingMaps},
+          ArrayRef<utils::IteratorType>{iteratorTypes},
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            // obtain the element in the tensor
+            Value element = args[0];
+
+            // make a cf.assert for the current element
+            b.create<mlir::cf::AssertOp>(loc, element, assertMessage.str());
+
+            b.create<linalg::YieldOp>(loc);
+          });
     } else {
       op.emitError("Unexpected type in triton::AssertOp");
       return failure();

include/triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 namespace mlir {
 namespace triton {

-static inline SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned n) {
+static inline SmallVector<utils::IteratorType>
+getNParallelLoopsAttrs(unsigned n) {
   return SmallVector<utils::IteratorType>(n, utils::IteratorType::parallel);
 }

0 commit comments
