Skip to content

Commit ab54fed

Browse files
authored
[hotfix] add kwargs for colo_addmm (#2171)
1 parent a110933 commit ab54fed

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

colossalai/nn/_ops/addmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
5555
mat2: ColoTensor,
5656
beta: Number = 1,
5757
alpha: Number = 1,
58-
*args) -> ColoTensor:
58+
**kargs) -> ColoTensor:
5959
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
6060
This method computes a linear.
6161
"""
@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
7070
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
7171
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
7272
ret_tensor = ColoTensor.from_torch_tensor(
73-
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
73+
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
7474
spec=ColoTensorSpec(mat2.get_process_group()))
7575
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
7676
if mat2.is_shard_1drow() and input_tensor.is_replicate():

0 commit comments

Comments
 (0)