Skip to content

Commit ebea9f2

Browse files
authored
convert reshape to view (#73)
1 parent c82a5a1 commit ebea9f2

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

megatron/core/tensor_parallel/layers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,7 @@ def backward(ctx, grad_output):
279279
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
280280
grad_output = grad_output.contiguous()
281281
# Convert the tensor shapes to 2D for execution compatibility
282-
# TODO: Is the reshape preventing us from getting a speedup here?
283-
grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1],
282+
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
284283
grad_output.shape[2])
285284
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
286285
total_input.shape[2])

0 commit comments

Comments
 (0)