Skip to content

Commit 5136e5a

Browse files
authored
Move some functional APIs to fairseq2.ops (#1165)
1 parent c7616e5 commit 5136e5a

File tree

18 files changed

+229
-62
lines changed

18 files changed

+229
-62
lines changed

src/fairseq2/generation/_sampling/_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from fairseq2.models.encoder_decoder import EncoderDecoderModel
2424
from fairseq2.models.sequence import SequenceModelOutput
2525
from fairseq2.nn import BatchLayout, IncrementalStateBag
26+
from fairseq2.ops import repeat_interleave
2627
from fairseq2.utils.stopwatch import Stopwatch
27-
from fairseq2.utils.tensor import repeat_interleave
2828

2929
# isort: split
3030

src/fairseq2/models/sequence.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99
from abc import ABC, abstractmethod
1010
from dataclasses import dataclass, field
11-
from typing import Literal, final
11+
from typing import Literal, Protocol, final
1212

1313
from torch import Tensor
1414
from torch.nn import Module
1515

1616
from fairseq2.nn import BatchLayout
17-
from fairseq2.nn.ops import CrossEntropy, cross_entropy
17+
from fairseq2.nn.functional import cross_entropy
1818

1919

2020
class SequenceModel(Module, ABC):
@@ -36,6 +36,18 @@ def forward(
3636
) -> SequenceModelOutput: ...
3737

3838

39+
class CrossEntropy(Protocol):
40+
def __call__(
41+
self,
42+
logits: Tensor,
43+
targets: Tensor,
44+
pad_idx: int | None,
45+
*,
46+
label_smoothing: float = 0.0,
47+
reduction: Literal["sum", "mean", "none"] = "sum",
48+
) -> Tensor: ...
49+
50+
3951
@final
4052
@dataclass
4153
class SequenceModelOutput:

src/fairseq2/models/transformer/_multihead_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
PositionEncoder,
3030
Projection,
3131
)
32-
from fairseq2.utils.tensor import repeat_interleave
32+
from fairseq2.ops import repeat_interleave
3333

3434
# isort: split
3535

src/fairseq2/models/transformer_lm/_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from typing_extensions import override
1414

1515
from fairseq2.models.decoder import DecoderModel
16-
from fairseq2.models.sequence import SequenceModelOutput
16+
from fairseq2.models.sequence import CrossEntropy, SequenceModelOutput
1717
from fairseq2.models.transformer import TransformerFrontend
1818
from fairseq2.nn import BatchLayout, IncrementalStateBag, Projection
19-
from fairseq2.nn.ops import CrossEntropy, cross_entropy
19+
from fairseq2.nn.functional import cross_entropy
2020

2121
# isort: split
2222

src/fairseq2/models/wav2vec2/_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from fairseq2.error import InternalError
2020
from fairseq2.models.transformer import TransformerEncoder
2121
from fairseq2.nn import BatchLayout, Linear
22-
from fairseq2.utils.tensor import repeat_interleave
22+
from fairseq2.ops import repeat_interleave
2323

2424
# isort: split
2525

src/fairseq2/nn/_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from fairseq2.error import InternalError
2424
from fairseq2.gang import Gang
2525
from fairseq2.nn.utils.module import to_empty
26-
from fairseq2.tensor_parallel import gather, reduce, reduce_on_backward
26+
from fairseq2.ops.tensor_parallel import gather, reduce, reduce_on_backward
2727

2828

2929
class Embedding(Module, ABC):

src/fairseq2/nn/_position_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from fairseq2.data_type import DataType
2323
from fairseq2.device import Device
2424
from fairseq2.error import InternalError
25-
from fairseq2.utils.tensor import unsqueeze
25+
from fairseq2.ops import unsqueeze
2626

2727
# isort: split
2828

src/fairseq2/nn/_projection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from fairseq2.error import InternalError
2525
from fairseq2.gang import Gang
2626
from fairseq2.nn.utils.module import to_empty
27-
from fairseq2.tensor_parallel import gather, reduce, reduce_on_backward, scatter
27+
from fairseq2.ops.tensor_parallel import gather, reduce, reduce_on_backward, scatter
2828

2929

3030
class Projection(Module, ABC):
Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,14 @@
66

77
from __future__ import annotations
88

9-
from typing import Literal, Protocol
9+
from typing import Literal
1010

1111
import torch
1212
from torch import Tensor
1313
from torch.nn.functional import log_softmax
1414
from torch.nn.functional import nll_loss as torch_nll_loss
1515

1616

17-
class CrossEntropy(Protocol):
18-
def __call__(
19-
self,
20-
logits: Tensor,
21-
targets: Tensor,
22-
pad_idx: int | None,
23-
*,
24-
label_smoothing: float = 0.0,
25-
reduction: Literal["sum", "mean", "none"] = "sum",
26-
) -> Tensor: ...
27-
28-
2917
def cross_entropy(
3018
logits: Tensor,
3119
targets: Tensor,
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from fairseq2.nn.functional._cross_entropy import cross_entropy as cross_entropy
10+
from fairseq2.nn.functional._nll_loss import nll_loss as nll_loss

0 commit comments

Comments
 (0)