
Commit 863a713

Fix diff-train errors (#315)

* fix mis-sync caused by diff-train
* remove comments
* fix lint

1 parent 78d1d8b

File tree: 6 files changed (+11, −23 lines)

test/test_gpu/skip_tests_h100_pytorch.yaml (1 addition, 0 deletions)

@@ -35,6 +35,7 @@ jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
 jagged_sum:
+gdpa:
 ragged_attention:
 # cpu-op for testing
 test_op:

test/test_gpu/skip_tests_h100_triton_main.yaml (1 addition, 0 deletions)

@@ -35,6 +35,7 @@ jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
 jagged_sum:
+gdpa:
 # cpu-op for testing
 test_op:
 # TODO: decoding attention requires updated xformers and flash_attn
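
For context, these YAML files list operators whose GPU tests should be skipped; adding "gdpa:" keeps the new operator out of the H100 CI runs. A minimal sketch of how such a skip list could be consumed (the loader below is a hypothetical illustration, not the actual tritonbench test harness):

    import yaml

    # Load the skip list; top-level keys are operator names to skip.
    with open("test/test_gpu/skip_tests_h100_pytorch.yaml") as f:
        skip_list = yaml.safe_load(f) or {}

    def should_skip(op_name: str) -> bool:
        # An operator is skipped if it appears as a top-level key.
        return op_name in skip_list

    assert should_skip("gdpa")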

tritonbench/operators/gdpa/gdpa.py (0 additions, 14 deletions)

@@ -35,25 +35,11 @@
 from .math import (
     activation_string_to_int,
     fast_gelu,
-    fast_gelu_bf16,
-    fast_gelu_bf16_grad,
     fast_gelu_grad,
-    fast_silu,
-    fast_silu_grad,
     gelu,
-    gelu_approx,
-    gelu_approx_grad,
     gelu_grad,
-    leaky_relu,
-    leaky_relu_grad,
     raw,
     raw_grad,
-    relu,
-    relu_grad,
-    silu,
-    silu_grad,
-    tanh,
-    tanh_approx_bf16,
     tanh_approx_fp32,
 )
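
These removed names are presumably the mis-sync the commit message refers to: if the synced copy of math.py no longer defines them, the stale import list fails as soon as gdpa.py is loaded, since Python resolves from-imports eagerly. A minimal illustration (module and function names here are made up for demonstration):

    # math_stub.py defines only: def gelu(x): return x
    from math_stub import gelu       # fine
    from math_stub import fast_silu  # raises ImportError at import time:
                                     # cannot import name 'fast_silu'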

tritonbench/operators/gdpa/gdpa_utils.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# (c) Meta Platforms, Inc. and affiliates.
 
 # pyre-strict
 import math

tritonbench/operators/gdpa/math.py (2 additions, 1 deletion)

@@ -1,4 +1,4 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# (c) Meta Platforms, Inc. and affiliates.
 
 # pyre-unsafe
 

@@ -84,6 +84,7 @@ def gelu_grad(x):
     @triton.jit
     def tanh_approx_fp32(x):
         return tanh(x)
+
 else:
 
     @triton.jit

tritonbench/operators/gdpa/operator.py (6 additions, 7 deletions)

@@ -18,14 +18,10 @@
 
 import argparse
 import gc
-import re
 from typing import Any, Callable, Generator, List, Optional
 
 import torch
 
-from .gdpa import gdpa
-from .gdpa_utils import generate_jagged_data
-
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,

@@ -35,6 +31,9 @@
     register_x_val,
 )
 
+from .gdpa import gdpa
+from .gdpa_utils import generate_jagged_data
+
 
 def calculate_memory_size(jagged_q, jagged_k, jagged_v, real_output, run_fwd, run_bwd):
     def tensor_size(tensor):

@@ -103,19 +102,19 @@ def parse_args(args):
         "--max_seq_len",
         default=1000,
         type=str,
-        help=f"Max sequence length for Q",
+        help="Max sequence length for Q",
     )
     parser.add_argument(
         "--dim",
         default=512,
         type=str,
-        help=f"Query dimension",
+        help="Query dimension",
     )
     parser.add_argument(
         "--head",
         default=4,
         type=str,
-        help=f"Multi head number",
+        help="Multi head number",
     )
     parser.add_argument(
         "--kv_len",
