Skip to content

Commit ae6977e

Browse files
committed
change default reduction in xla to mean
1 parent d29839b commit ae6977e

File tree

1 file changed

+3
-1
lines changed
  • src/lightning/pytorch/strategies

1 file changed

+3
-1
lines changed

src/lightning/pytorch/strategies/xla.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
247247

248248
@override
249249
def reduce(
250-
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
250+
self,
251+
output: Union[Tensor, Any], group: Optional[Any] = None,
252+
reduce_op: Optional[Union[ReduceOp, str]] = "mean"
251253
) -> Tensor:
252254
if not isinstance(output, Tensor):
253255
output = torch.tensor(output, device=self.root_device)

0 commit comments

Comments
 (0)