
Commit ad40bed

[bwd] Add default backward pass function (#520)
1 parent fdb7ac7 commit ad40bed

File tree

35 files changed: +65 / -145 lines
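
The commit title refers to a default backward-pass function added to the shared BenchmarkOperator base class; the per-operator get_bwd_fn methods removed in the files below all follow the same pattern (run the forward, seed an upstream gradient if needed, return a closure that calls backward with retain_graph=True). A minimal sketch of what such a default could look like is shown here; the class and method names come from the diff, but the body is an assumption, not the committed implementation.

from typing import Callable

import torch


class BenchmarkOperator:
    # Forward-only operators opt out of backward benchmarking entirely
    # (see the FWD_ONLY = True additions below).
    FWD_ONLY = False

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        # Run the forward once so the autograd graph exists.
        y = fwd_fn()
        # Some operators (e.g. flex_attention) return tuples or other
        # structures; pick the first tensor that participates in autograd.
        if not isinstance(y, torch.Tensor):
            y = next(
                t for t in y if isinstance(t, torch.Tensor) and t.requires_grad
            )
        if y.numel() == 1:
            # Scalar losses (cross_entropy, fused_linear_cross_entropy)
            # can call backward() without an explicit gradient.
            return lambda: y.backward(retain_graph=True)
        # Non-scalar outputs (embedding, flex_attention) need an upstream grad.
        dy = torch.randn_like(y)
        return lambda: y.backward(dy, retain_graph=True)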

tritonbench/operators/addmm/operator.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "best_config"]
     DEFAULT_PRECISION = "fp16"
+    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
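
addmm is the first of several operators below (the bf16 and fp8 GEMM variants and fp8_attention) that now declare FWD_ONLY = True instead of carrying backward-specific code. The attribute name comes from the diff; how the harness consumes it is not shown in this commit, but a plausible guard would look something like the following hypothetical helper, which is not part of the change:

def check_mode(op, mode: str) -> None:
    # Hypothetical harness-side guard: operators that set FWD_ONLY = True
    # are rejected for any backward-style benchmarking mode.
    if mode != "fwd" and getattr(op, "FWD_ONLY", False):
        raise NotImplementedError(
            f"{type(op).__name__} is forward-only; mode {mode!r} is unsupported."
        )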

tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@

 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

tritonbench/operators/cross_entropy/operator.py

Lines changed: 0 additions & 5 deletions
@@ -79,11 +79,6 @@ def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         v = example_inputs[0].size(-1)
         return (self.B, self.T, v)

-    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
-        y = fwd_fn()
-        # TODO: how to pass grad_to_none=[_input]?
-        return lambda: y.backward(retain_graph=True)
-
     def get_grad_to_none(self, args) -> List[torch.Tensor]:
         x = args[0]
         return [x]

tritonbench/operators/embedding/operator.py

Lines changed: 0 additions & 5 deletions
@@ -60,8 +60,3 @@ def torch_compile_embedding(self, V, D, input, shared_weight) -> Callable:
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         V, D, input_tensor, _ = example_inputs
         return (input_tensor.size(0), input_tensor.size(1), D, V)
-
-    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
-        y = fwd_fn()
-        do = torch.randn_like(y)
-        return lambda: y.backward(do, retain_graph=True)

tritonbench/operators/flex_attention/operator.py

Lines changed: 0 additions & 10 deletions
@@ -419,16 +419,6 @@ def sdpa_fn():

         return sdpa_fn

-    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
-        o = fwd_fn()
-        o_tensor = input_filter(
-            lambda x: isinstance(x, torch.Tensor) and x.requires_grad,
-            o,
-        )
-        assert o_tensor is not None, "No tensor found in output that requires grad."
-        do = torch.rand_like(o_tensor)
-        return lambda: o_tensor.backward(do, retain_graph=True)
-
     def get_grad_to_none(self, args) -> List[torch.Tensor]:
         """Return tensors whose gradients should be set to None between iterations."""
         q, k, v, *_ = args
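
Of the removed methods, this flex_attention version was the most general: it filtered the forward output for a tensor requiring grad before seeding a random gradient. A shared default presumably needs an equivalent step for operators whose forward returns a tuple or other structure, which is why the sketch near the top of this page includes one.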

tritonbench/operators/fp8_attention/operator.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "tflops"]
     DEFAULT_PRECISION = "fp8"
+    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def parse_args(args):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
+    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

tritonbench/operators/fp8_gemm_blockwise/operator.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def fp8_block_quantize(
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp8"
+    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

tritonbench/operators/fp8_gemm_rowwise_grouped/operator.py

Lines changed: 1 addition & 0 deletions
@@ -333,6 +333,7 @@ class Operator(BenchmarkOperator):

     DEFAULT_METRICS = ["tflops", "gbps", "speedup", "accuracy"]
     DEFAULT_PRECISION = "fp8"
+    FWD_ONLY = True

     def __init__(
         self,

tritonbench/operators/fused_linear_cross_entropy/operator.py

Lines changed: 0 additions & 4 deletions
@@ -108,7 +108,3 @@ def torch_compile_fused_linear_cross_entropy(
     @register_x_val(label="(B*T, H)")
     def get_x_val(self, example_inputs) -> Tuple[int, int]:
         return (example_inputs[0].size(0), example_inputs[0].size(1))
-
-    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
-        y = fwd_fn()
-        return lambda: y.backward(retain_graph=True)
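
With the default in place, operators such as cross_entropy, embedding, and fused_linear_cross_entropy need no per-operator code to be benchmarked in backward mode. Assuming tritonbench's usual entry point and flags (not shown in this diff), a backward run would look like:

python run.py --op embedding --mode bwd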

0 commit comments
