
Commit 86b2d82

Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (pytorch#166165)"
This reverts commit 94f2657. Reverted pytorch#166165 on behalf of https://github.com/izaitsevfb because it breaks the test_LinearAndSoftmax_codegen test (see the comment on pytorch#166165).
1 parent eea8ff2 commit 86b2d82

File tree (3 files changed: +5 −71 lines)

- test/inductor/test_torchinductor.py
- torch/_inductor/fx_passes/post_grad.py
- torch/_inductor/utils.py

test/inductor/test_torchinductor.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -15280,7 +15280,7 @@ def fn3(x):
             ),
             (
                 fn3,
-                "triton_poi_fused_addmm_native_layer_norm",
+                "triton_poi_fused_native_layer_norm_relu",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]
@@ -15293,7 +15293,7 @@ def fn3(x):
             ),
             (
                 fn3,
-                "triton_poi_fused_LayerNorm_Linear_ReLU",
+                "triton_poi_fused_LayerNorm_ReLU",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]
```
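The expected kernel names drop the addmm/Linear component: after the revert, the addmm bias is no longer unfused into the downstream pointwise kernel, so the fused Triton kernel only covers the layer norm and ReLU. A minimal sketch of the kind of pattern the test exercises, assuming fn3 is roughly a Linear feeding LayerNorm and ReLU (its real body is not part of this diff):

```python
# Hypothetical stand-in for the test's fn3 (its real definition is not shown
# in this diff): a Linear whose output feeds LayerNorm and then ReLU.
import torch

linear = torch.nn.Linear(4, 4)
norm = torch.nn.LayerNorm(4)

def fn3(x):
    # addmm (Linear) -> native_layer_norm -> relu
    return torch.relu(norm(linear(x)))

# Compiling with Inductor produces a fused Triton pointwise kernel whose name
# the test compares against the expected strings in the diff above.
compiled = torch.compile(fn3)
out = compiled(torch.randn(4, 4))
```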

torch/_inductor/fx_passes/post_grad.py

Lines changed: 3 additions & 5 deletions

```diff
@@ -51,8 +51,8 @@
     decode_device,
     get_all_devices,
     get_gpu_type,
-    has_uses_tagged_as,
     is_gpu,
+    is_pointwise_use,
     OPTIMUS_EXCLUDE_POST_GRAD,
 )
 from ..virtualized import V
@@ -1510,10 +1510,8 @@ def should_prefer_unfused_addmm(match):
     if not is_gpu(inp.meta["val"].device.type):
         return False

-    return has_uses_tagged_as(
-        match.output_node(),
-        (torch.Tag.pointwise, torch.Tag.reduction),
-    )
+    output = match.output_node()
+    return all(is_pointwise_use(use) for use in output.users)


 @register_graph_pattern(
```
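The restored predicate prefers the unfused form only when every direct user of the addmm output is a pointwise use, whereas the reverted helper accepted any transitive use tagged pointwise or reduction (following views). A standalone sketch of that difference, using plain sets of tag strings rather than real FX nodes (this is not the actual Inductor code):

```python
# Each "user" of the addmm output is modeled here as just a set of tag strings.

def any_use_tagged(users, tags=("pointwise", "reduction")):
    # Pre-revert behaviour of has_uses_tagged_as: one matching consumer is
    # enough to prefer the unfused mm + bias-add form.
    return any(tag in user_tags for user_tags in users for tag in tags)

def all_uses_pointwise(users):
    # Restored behaviour: every direct consumer must be a pointwise use.
    return all("pointwise" in user_tags for user_tags in users)

# Example: one pointwise consumer and one matmul consumer.
users = [{"pointwise"}, {"matmul"}]
print(any_use_tagged(users))       # True  -> pre-revert check would unfuse
print(all_uses_pointwise(users))   # False -> restored check keeps addmm fused
```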

torch/_inductor/utils.py

Lines changed: 0 additions & 64 deletions

```diff
@@ -553,70 +553,6 @@ def is_pointwise_use(
     return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


-class LogicalConnective(enum.Enum):
-    OR = enum.auto()
-    AND = enum.auto()
-
-
-def has_uses(
-    target: Node,
-    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
-    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
-) -> bool:
-    """
-    Given a target, explore the uses of `target` by applying `use_selector_fn`
-    on them, and then aggregate these booleans with the `use_aggregate_type`
-    logical connective.
-
-    Uses in view ops will follow the views uses.
-    """
-
-    def get_use_aggregate_fn(
-        use_aggregate_type: LogicalConnective,
-    ) -> Callable[[Iterator[Any]], bool]:
-        match use_aggregate_type:
-            case LogicalConnective.AND:
-                return all
-            case LogicalConnective.OR:
-                return any
-            case _:
-                return any
-
-    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
-
-    def has_uses_impl(use: Node) -> bool:
-        if use.op != "call_function":
-            return False
-        if not (
-            isinstance(use.target, torch._ops.OpOverload)
-            or use.target is operator.getitem
-        ):
-            return False
-
-        target = cast(torch._ops.OpOverload, use.target)
-        # Process getitem and view
-        if target is operator.getitem or is_view(target):
-            return use_aggregate_fn(has_uses_impl(user) for user in use.users)
-
-        return use_selector_fn(target)
-
-    return use_aggregate_fn(has_uses_impl(user) for user in target.users)
-
-
-def has_uses_tagged_as(
-    target: Node,
-    use_tags: Collection[torch.Tag],
-    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
-) -> bool:
-    """
-    Is there a use with given tags?
-    """
-
-    return has_uses(
-        target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
-    )
-
-
 def gen_gm_and_inputs(
     target: Any, args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[GraphModule, list[torch.Tensor]]:
```
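For reference, the removed helpers aggregated per-use booleans with a logical connective and looked through view/getitem users. A self-contained toy sketch of that pattern, using a hypothetical ToyNode class instead of torch.fx.Node (the names here are illustrative, not the real API):

```python
import enum
from dataclasses import dataclass, field
from typing import Callable


class LogicalConnective(enum.Enum):
    OR = enum.auto()
    AND = enum.auto()


@dataclass
class ToyNode:
    name: str
    tags: set = field(default_factory=set)
    is_view: bool = False
    users: list = field(default_factory=list)


def has_uses(
    target: ToyNode,
    use_selector_fn: Callable[[ToyNode], bool],
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
    # Pick the aggregation function that matches the requested connective.
    aggregate = all if use_aggregate_type is LogicalConnective.AND else any

    def visit(use: ToyNode) -> bool:
        if use.is_view:
            # View-like uses are transparent: look through to their users.
            return aggregate(visit(u) for u in use.users)
        return use_selector_fn(use)

    return aggregate(visit(u) for u in target.users)


def has_uses_tagged_as(target: ToyNode, use_tags: set) -> bool:
    # "Is there a use carrying one of these tags?" (OR aggregation).
    return has_uses(target, lambda use: bool(use.tags & use_tags))


# addmm -> view -> relu (pointwise): the tagged use is found through the view.
relu = ToyNode("relu", tags={"pointwise"})
view = ToyNode("view", is_view=True, users=[relu])
addmm = ToyNode("addmm", users=[view])
print(has_uses_tagged_as(addmm, {"pointwise", "reduction"}))  # True
```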
