
Commit 7439a54

ydwu4 authored and can-gaa-hou committed
[while_loop][inductor] fix aliased inputs by cloning (pytorch#160668)
[fx_graph_cse](https://github.com/pytorch/pytorch/blob/main/torch/_functorch/compile_utils.py#L46) runs in the min-cut partitioner and accidentally creates aliasing among the empty buffers. Running

```
pytest test/functorch/test_control_flow.py -k test_scan_multiple_layers_gradient_layers_2_device_cpu
```

produces the following node in the joint graph:

```python
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0_0, while_loop_body_graph_0_0, (full_default_4, empty_strided_default, full_default_2, full_default_3, full_default_2, full_default_3, full_default, full_default, rev, rev_1, rev_2, rev_3), (primals_4, primals_5, primals_6, primals_7));
```

Note the operand sequence **"full_default_2, full_default_3, full_default_2, full_default_3, full_default, full_default"**: the gradients of different layers now share the same buffer, which causes silent incorrectness.

Fixes pytorch#158168.

Pull Request resolved: pytorch#160668
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#160548, pytorch#160374
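To see why the CSE-introduced aliasing is harmful, here is a minimal eager-mode sketch. It is not the partitioner or Inductor code; `accumulate` is a made-up stand-in for the per-layer gradient accumulation that the while_loop backward performs in place on its carried buffers:

```python
# Sketch only: why CSE-ing two identical buffer allocations is unsafe
# once a while_loop body mutates its carried buffers in place.
import torch

def accumulate(grad_l0, grad_l1):
    # Each "layer" accumulates into its own carried buffer in place,
    # analogous to the gradient carries of the while_loop backward.
    grad_l0.add_(1.0)
    grad_l1.add_(2.0)
    return grad_l0, grad_l1

# Distinct buffers: each layer keeps its own gradient.
g0, g1 = torch.zeros(3), torch.zeros(3)
print(accumulate(g0, g1))          # (all 1s, all 2s)

# After CSE, the two identical zero allocations collapse into one node,
# so both operands point at the same storage.
shared = torch.zeros(3)
print(accumulate(shared, shared))  # both show 3s: silently wrong gradients
```

Cloning the duplicated carried inputs, as this PR does in Inductor's `WhileLoop.create`, gives each layer its own gradient buffer again.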
1 parent c74e084 commit 7439a54

2 files changed: +198 −0 lines changed

test/functorch/test_control_flow.py

Lines changed: 163 additions & 0 deletions
```diff
@@ -2947,6 +2947,169 @@ def RNN(x: torch.Tensor, y: torch.Tensor):
             params,
         )
 
+    @requires_cuda
+    @skipIfTorchDynamo("not a dynamo test")
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("layers", [1, 2, 3])
+    @parametrize("device", ["cpu", "cuda"])
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_scan_multiple_layers_gradient(self, layers, device):
+        import torch.nn as nn
+
+        torch.manual_seed(1)
+
+        LAYERS = layers
+        BATCH_SIZE = 2
+        SEQ_LEN = 5
+        FEATURE_DIM = 10
+        DEVICE = device
+
+        class RNNLoop(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = nn.ModuleList(
+                    [nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
+                )
+                self.num_layers = LAYERS
+
+            def forward(self, initial, inputs_sequence):
+                B, T, _ = inputs_sequence.shape
+                hs_list = initial
+                all_out = []
+                for t in range(T):
+                    input = inputs_sequence[:, t, :]
+                    for li, layer in enumerate(self.layers):
+                        input_concat = torch.cat((hs_list[li], input), dim=-1)
+                        update = layer(input_concat)
+                        hs_list[li] = hs_list[li] + update
+                        input = hs_list[li]
+
+                    all_out.append(input)
+
+                return torch.stack(all_out, dim=1)
+
+        class RNNScanList(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = nn.ModuleList(
+                    [nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
+                )
+                self.num_layers = LAYERS
+
+            def forward(self, initial, input_sequence):
+                def step(carry, input):
+                    hs_list = carry[:]
+                    for li, layer in enumerate(self.layers):
+                        h_prev_li = hs_list[li]
+                        input_concat = torch.cat((h_prev_li, input), dim=-1)
+                        update = layer(input_concat)
+                        h_curr_li = h_prev_li + update
+                        hs_list[li] = h_curr_li
+                        input = h_curr_li
+                    return [t.clone() for t in hs_list], input.clone()
+
+                _, all_outputs_scan = scan(step, initial, input_sequence, dim=1)
+                return all_outputs_scan.transpose(0, 1)
+
+        class RNNScanTensor(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = nn.ModuleList(
+                    [nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
+                )
+                self.num_layers = LAYERS
+
+            def forward(self, initial, input_sequence):
+                def step(carry_tensor, xs_input):
+                    input = xs_input
+                    hs_tensor = carry_tensor
+                    for li, layer in enumerate(self.layers):
+                        current_h_prev_li_slice = hs_tensor[:, li, :]
+                        input_concat = torch.cat(
+                            (current_h_prev_li_slice, input), dim=-1
+                        )
+                        update = layer(input_concat)
+                        h_curr_li = current_h_prev_li_slice + update
+                        hs_tensor = hs_tensor.clone()
+                        hs_tensor[:, li, :] = h_curr_li
+                        input = h_curr_li
+                    return hs_tensor.clone(), input.clone()
+
+                hs_stacked = torch.stack(initial, dim=1)
+                _, all_outputs_scan = scan(step, hs_stacked, input_sequence, dim=1)
+                return all_outputs_scan.transpose(0, 1)
+
+        def run_test_and_get_grads_loss(model, initial_hs, inputs):
+            for param in model.parameters():
+                if param.grad is not None:
+                    param.grad.zero_()
+
+            current_initial_hs = [
+                h.detach().clone().requires_grad_(h.requires_grad) for h in initial_hs
+            ]
+            current_inputs = (
+                inputs.detach().clone().requires_grad_(inputs.requires_grad)
+            )
+
+            out = model(current_initial_hs, current_inputs)
+            loss = out.sum()
+            loss.backward()
+
+            layer_grads = []
+            for layer in model.layers:
+                layer_grads.append(layer.weight.grad.clone())
+
+            return layer_grads, loss
+
+        torch.manual_seed(0)
+
+        initial_hs_template = [
+            torch.zeros(
+                BATCH_SIZE, FEATURE_DIM, requires_grad=True, dtype=torch.float32
+            ).to(DEVICE)
+            for _ in range(LAYERS)
+        ]
+        inputs_template = torch.randn(
+            BATCH_SIZE, SEQ_LEN, FEATURE_DIM, requires_grad=True, dtype=torch.float32
+        ).to(DEVICE)
+
+        # Test 3 models: RNNScanList, RNNScanTensor, RNNLoop
+        models = [
+            ("ScanList", RNNScanList),
+            ("ScanTensor", RNNScanTensor),
+            ("Loop", RNNLoop),
+        ]
+
+        for model_name, model_class in models:
+            # Create uncompiled model
+            model_uc = model_class().to(DEVICE)
+            uncompiled_grads, uncompiled_loss = run_test_and_get_grads_loss(
+                model_uc, initial_hs_template, inputs_template
+            )
+
+            # Create compiled model with same weights
+            model_to_compile = model_class().to(DEVICE)
+            model_to_compile.load_state_dict(model_uc.state_dict())
+            compiled_model = torch.compile(model_to_compile)
+            compiled_grads, compiled_loss = run_test_and_get_grads_loss(
+                compiled_model, initial_hs_template, inputs_template
+            )
+
+            # Compare gradients for each layer
+            for i, (uncompiled_grad, compiled_grad) in enumerate(
+                zip(uncompiled_grads, compiled_grads)
+            ):
+                self.assertEqual(
+                    uncompiled_grad,
+                    compiled_grad,
+                )
+
+            # Compare losses
+            self.assertEqual(
+                uncompiled_loss,
+                compiled_loss,
+            )
+
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @parametrize("reverse", [False, True])
```

torch/_inductor/ir.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -8525,6 +8525,8 @@ def _split_by_sym_type(
 
 @ir_dataclass(frozen=False)
 class WhileLoop(ExternKernel):
+    """IR node for while_loop, which supports input mutations"""
+
     carried_inputs: Optional[Sequence[IRNode]] = None
     additional_inputs: Optional[Sequence[IRNode]] = None
     cond_subgraph: Optional[Subgraph] = None
@@ -8557,6 +8559,38 @@ def __init__(
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)
 
+    # Accidental aliasing can be created by cse, where the empty buffers we
+    # allocated for backward to use get csed into the same buffer in fx_graph_cse.
+    # See test_scan_multiple_layers_gradient for a concrete example.
+    @staticmethod
+    def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]:
+        if not _has_aliased_buffers(carried_inputs):
+            return carried_inputs
+
+        # Import clone from the lowering module
+        from .lowering import clone
+
+        # Unwrap views to get the underlying buffers for comparison
+        unwrapped_buffers = [
+            buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
+            for buffer in carried_inputs
+        ]
+
+        # Track which underlying buffers we have already seen (by object id)
+        seen_buffers: OrderedSet[int] = OrderedSet()
+        result = []
+
+        for i, (original_input, unwrapped_buffer) in enumerate(
+            zip(carried_inputs, unwrapped_buffers)
+        ):
+            if id(unwrapped_buffer) in seen_buffers:
+                result.append(clone(original_input))
+            else:
+                seen_buffers.add(id(unwrapped_buffer))
+                result.append(original_input)
+
+        return result
+
     @classmethod
     def create(
         cls,
@@ -8592,6 +8626,7 @@ def _require_exact_strides(
         fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs]  # type: ignore[union-attr]
 
         carried_inputs_ = [cls.realize_input(x) for x in carried_inputs]
+        carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_)
        carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs)
         additional_inputs_ = [cls.realize_input(x) for x in additional_inputs]
         additional_inputs_ = _require_exact_strides(
```
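For readers less familiar with Inductor IR, the de-aliasing step above can be illustrated on plain tensors. The following is a stand-alone sketch of the same dedupe-by-identity-and-clone idea, not the Inductor code path; the `clone_aliased` helper and the storage-pointer key are illustrative choices:

```python
# Sketch of WhileLoop._clone_aliased_inputs applied to plain tensors
# instead of Inductor IR nodes (names here are illustrative, not Inductor API).
import torch

def clone_aliased(carried):
    seen_storages = set()
    result = []
    for t in carried:
        # Two carries alias if they share the same underlying storage.
        key = t.untyped_storage().data_ptr()
        if key in seen_storages:
            # Later duplicates get their own buffer, so in-place writes to
            # one carry can no longer corrupt another.
            result.append(t.clone())
        else:
            seen_storages.add(key)
            result.append(t)
    return result

buf = torch.zeros(3)
carries = [buf, buf, torch.zeros(3)]  # first two alias each other
dealiased = clone_aliased(carries)
assert dealiased[0].data_ptr() != dealiased[1].data_ptr()
```

In the actual fix the comparison is done on the unwrapped IR buffers (via `unwrap_view`) and cloning goes through Inductor's `lowering.clone`, but the invariant is the same: after the pass, no two carried inputs of the `while_loop` share a buffer.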
