
Commit 3e8c6df

Author: pengcheng888

issue/596 - Split the functions in functional.py into functions under the functional folder
1 parent 1a618ff commit 3e8c6df

File tree

9 files changed: +161 −128 lines

python/infinicore/nn/__init__.py
python/infinicore/nn/functional.py (deleted)
python/infinicore/nn/functional/__init__.py
python/infinicore/nn/functional/causal_softmax.py
python/infinicore/nn/functional/random_sample.py
python/infinicore/nn/functional/rms_norm.py
python/infinicore/nn/functional/silu.py
python/infinicore/nn/functional/swiglu.py
test/infinicore/ops/random_sample.py

python/infinicore/nn/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -1,3 +1,3 @@
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
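
The collapsed re-import plus an explicit __all__ leaves the public surface of infinicore.nn unchanged. A minimal sketch of what still resolves after this commit (assuming an already-constructed infinicore Tensor `x`; tensor creation is out of scope here):

    import infinicore.nn as nn

    # `functional` is still reachable through the nn namespace:
    # y = nn.functional.silu(x)   # x: an existing infinicore Tensor (assumed)
    print(nn.__all__)  # ['functional']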

python/infinicore/nn/functional.py

Lines changed: 0 additions & 101 deletions
This file was deleted.
python/infinicore/nn/functional/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from .causal_softmax import causal_softmax
+from .random_sample import random_sample
+from .rms_norm import rms_norm
+from .silu import silu
+from .swiglu import swiglu
+
+__all__ = [
+    "causal_softmax",
+    "random_sample",
+    "rms_norm",
+    "silu",
+    "swiglu",
+]
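
Because the new package __init__.py re-exports every function, the flat and the per-module import styles resolve to the same objects; a quick sanity check (a sketch, assuming the package lands as shown in this diff):

    import infinicore.nn.functional as F
    from infinicore.nn.functional import rms_norm

    # Both names point at the same function object after the split.
    assert F.rms_norm is rms_norm
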
python/infinicore/nn/functional/causal_softmax.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["causal_softmax"]
+
+
+def causal_softmax(input: Tensor, out=None) -> Tensor:
+    r"""Apply a causal softmax function."""
+
+    if out is None:
+        return Tensor(_infinicore.causal_softmax(input._underlying))
+
+    _infinicore.causal_softmax_(out._underlying, input._underlying)
+
+    return out
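
For reference, a sketch of causal-softmax semantics in plain PyTorch (an illustration only, not infinicore's kernel): each row may attend only to its own and earlier columns, and masked positions are excluded from the softmax.

    import torch

    def causal_softmax_ref(x: torch.Tensor) -> torch.Tensor:
        # Mask out "future" columns, then softmax over the last dim.
        seq_q, seq_k = x.shape[-2], x.shape[-1]
        mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=x.device)
        mask = mask.tril(diagonal=seq_k - seq_q)  # also handles seq_q != seq_k
        return torch.softmax(x.masked_fill(~mask, float("-inf")), dim=-1)

The out= convention mirrors the other wrappers: omit it to allocate a fresh tensor, or pass a preallocated tensor to hit the in-place _infinicore.causal_softmax_ binding.
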
python/infinicore/nn/functional/random_sample.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["random_sample"]
+
+
+def random_sample(
+    logits: Tensor,
+    random_val: float,
+    topp: float,
+    topk: int,
+    temperature: float,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Sample an index from logits with nucleus/top-k filtering."""
+
+    if out is None:
+        return Tensor(
+            _infinicore.random_sample(
+                logits._underlying,
+                random_val,
+                topp,
+                topk,
+                temperature,
+            )
+        )
+
+    _infinicore.random_sample_(
+        out._underlying,
+        logits._underlying,
+        random_val,
+        topp,
+        topk,
+        temperature,
+    )
+
+    return out
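
A hedged sketch of the sampling rule these parameters suggest, written against plain PyTorch for a 1-D logits vector. The real kernel is _infinicore.random_sample; details such as the exact top-p cutoff and threshold scaling are assumptions here, not the binding's contract.

    import torch

    def random_sample_ref(logits, random_val, topp, topk, temperature):
        # Temperature-scaled softmax, sorted descending for filtering.
        probs = torch.softmax(logits.float() / temperature, dim=-1)
        sorted_probs, sorted_indices = probs.sort(descending=True)
        cum_probs = sorted_probs[:topk].cumsum(dim=0)        # top-k filter
        # random_val in [0, 1) picks a point inside the retained mass (top-p).
        threshold = random_val * min(topp, float(cum_probs[-1]))
        idx = torch.searchsorted(cum_probs, threshold)
        return sorted_indices[min(int(idx), cum_probs.numel() - 1)]
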
python/infinicore/nn/functional/rms_norm.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["rms_norm"]
+
+
+def rms_norm(
+    input: Tensor,
+    normalized_shape: list[int],
+    weight: Tensor,
+    eps: float = 1e-5,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Apply Root Mean Square Layer Normalization."""
+
+    assert normalized_shape == weight.shape, (
+        "normalized_shape does not match weight.shape."
+    )
+
+    if out is None:
+        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))
+
+    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
+
+    return out
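
Note that normalized_shape is only validated against weight.shape; the kernel call itself receives just input, weight, and eps. For reference, the usual RMSNorm math (a sketch in PyTorch, not the infinicore implementation):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Normalize by the root mean square over the last dimension, then scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x / rms * weight
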
python/infinicore/nn/functional/silu.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import infinicore
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["silu"]
+
+
+def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
+    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
+
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
+        return infinicore.ntops.torch.silu(input, inplace=inplace)
+
+    if inplace:
+        _infinicore.silu_(input._underlying, input._underlying)
+        return input
+
+    if out is None:
+        return Tensor(_infinicore.silu(input._underlying))
+
+    _infinicore.silu_(out._underlying, input._underlying)
+
+    return out
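
silu is the only wrapper with an extra dispatch step: on cuda/musa devices it can route to infinicore.ntops.torch.silu when use_ntops is set, and it additionally accepts inplace=. The underlying math is the standard SiLU, shown here in PyTorch for reference:

    import torch

    # SiLU (a.k.a. swish): silu(x) = x * sigmoid(x)
    x = torch.randn(4)
    assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))
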
python/infinicore/nn/functional/swiglu.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["swiglu"]
+
+
+def swiglu(input: Tensor, other: Tensor, *, out=None):
+    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
+
+    if out is None:
+        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
+
+    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
+
+    return out
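
For orientation, SwiGLU combines a SiLU-gated branch with a linear branch elementwise; one common formulation is sketched below in PyTorch. Which argument acts as the gate in the _infinicore binding is not visible in this diff, so the order here is an assumption.

    import torch

    # One common convention (assumed): swiglu(input, other) = input * silu(other).
    a, b = torch.randn(4), torch.randn(4)
    ref = a * torch.nn.functional.silu(b)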

test/infinicore/ops/random_sample.py

Lines changed: 28 additions & 24 deletions
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
             idx = torch.searchsorted(cum_probs, threshold)
         except Exception:
             indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
-            idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            idx = (
+                indices[0]
+                if indices.numel() > 0
+                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            )
         return sorted_indices[idx]

     return torch.argmax(data)

@@ -191,41 +195,41 @@ def infinicore_operator(self, logits, out=None, **kwargs):
     def run_test(self, device, test_case, config):
         """
         Override run_test to handle random_sample's special comparison logic.
-
+
         For random_sample, if the indices differ but the logits values at those
         indices are equal, the result is still considered valid. This handles
         cases where multiple valid indices exist due to floating-point precision.
-
+
         This is necessary because random_sample can return different valid indices
         when multiple positions have the same logits value, especially with
         low-precision types like bfloat16 due to floating-point rounding.
         """
         # Clear stored logits before test to ensure fresh generation
         self._current_logits = None
-
+
         try:
             # Try the standard comparison first
             # This will call prepare_inputs_and_kwargs which will set self._current_logits
             return super().run_test(device, test_case, config)
-        except AssertionError:
+        except AssertionError as original_error:
             # If standard comparison fails, check if this is a valid case where
             # indices differ but logits values are equal
-
+
             # Only handle if we have stored logits (from prepare_inputs_and_kwargs)
             if self._current_logits is None:
                 raise
-
+
             logits_tensor = self._current_logits
-
+
             # Re-run operations with the same logits to get results for comparison
             # prepare_inputs_and_kwargs will reuse self._current_logits if it exists
             from framework.utils import (
                 infinicore_tensor_from_torch,
                 convert_infinicore_to_torch,
             )
-
+
             inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
-
+
             # Prepare infinicore inputs
             infini_inputs = []
             for inp in inputs:

@@ -235,51 +239,51 @@ def run_test(self, device, test_case, config):
                     infini_inputs.append(infini_tensor)
                 else:
                     infini_inputs.append(inp)
-
+
             infini_kwargs = kwargs.copy()
-            if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
+            if "out" in infini_kwargs and isinstance(
+                infini_kwargs["out"], torch.Tensor
+            ):
                 cloned_out = infini_kwargs["out"].clone().detach()
                 infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
-
+
             # Run both operators
             torch_result = self.torch_operator(*inputs, **kwargs)
             infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
-
+
             # Extract indices from results
             comparison_target = test_case.comparison_target
             if comparison_target == "out":
                 # Compare output tensor from kwargs
                 ref_idx = kwargs["out"].item()
                 torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_kwargs["out"], kwargs["out"]
+                    infini_kwargs["out"]
                 )
                 ic_idx = torch_result_from_infini.item()
             else:
                 # Compare return values
                 ref_idx = torch_result.item()
-                torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_result, torch_result
-                )
+                torch_result_from_infini = convert_infinicore_to_torch(infini_result)
                 ic_idx = torch_result_from_infini.item()
-
+
             # Check if indices are equal (standard case)
             if ic_idx == ref_idx:
-                return
-
+                return True, "passed"
+
             # Special case: indices differ but logits values are equal
             # This is valid for random_sample when multiple indices have the same logits value
             try:
                 logits_ref = logits_tensor[ref_idx].item()
                 logits_ic = logits_tensor[ic_idx].item()
                 if logits_ic == logits_ref:
                     # Valid: different indices but same logits value
-                    return
+                    return True, "passed"
             except (IndexError, RuntimeError):
                 # If we can't access the logits, fall through to raise the original error
                 pass
-
+
             # If we get here, the results are truly different
-            raise
+            raise original_error


 def main():
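
The fallback acceptance above exists because equal logits values make several indices equally valid picks; a tiny self-contained illustration of the situation the docstring describes:

    import torch

    # With low-precision dtypes, distinct positions often hold identical
    # values, so two implementations can legitimately return different
    # argmax indices for the same input.
    logits = torch.tensor([0.3, 0.3, 0.1], dtype=torch.bfloat16)
    ref_idx, ic_idx = 0, 1
    assert logits[ref_idx].item() == logits[ic_idx].item()  # both picks are valid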
