
Commit 651c637

Merge branch 'nn7_infini' into issue/586
2 parents 271d287 + 73cb007

39 files changed: +4285 −337 lines

README.md

Lines changed: 4 additions & 1 deletion

````diff
@@ -181,7 +181,10 @@ pip install . -e
 #### Run the InfiniCore Python operator interface tests
 
 ```bash
-python test/infinicore/run.py --verbose --bench [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
+# Test a single operator
+python test/infinicore/ops/[operator].py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon]
+# Test all operators
+python test/infinicore/run.py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
 ```
 
 Use -h to see more options.
````
include/infinicore/ops/embedding.hpp

Lines changed: 9 additions & 0 deletions

```cpp
#pragma once

#include "common/op.hpp"

namespace infinicore::op {

Tensor embedding(Tensor input, Tensor weight);
void embedding_(Tensor out, Tensor input, Tensor weight);
} // namespace infinicore::op
```
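Out of place, `embedding` returns a new tensor; `embedding_` writes into a preallocated `out`. Semantically the op is a row gather: each integer id in `input` selects one row of `weight`. A minimal NumPy sketch of those semantics (an illustration only, not the InfiniCore kernel):

```python
import numpy as np

def embedding_ref(input_ids: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Reference semantics: each id selects a row of the (vocab, dim) table."""
    return weight[input_ids]

weight = np.arange(12, dtype=np.float32).reshape(4, 3)  # vocab=4, dim=3
ids = np.array([[1, 3], [0, 0]])                        # any integer shape
out = embedding_ref(ids, weight)                        # shape (2, 2, 3)
assert (out[0, 1] == weight[3]).all()
```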

include/infinicore/ops/linear.hpp

Lines changed: 12 additions & 0 deletions

```cpp
#pragma once

#include "common/op.hpp"
#include <pybind11/pybind11.h>

namespace infinicore::op {

Tensor linear(Tensor input, Tensor weight, pybind11::object bias);

void linear_(Tensor out, Tensor input, Tensor weight, pybind11::object bias);

} // namespace infinicore::op
```
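`bias` is a `pybind11::object` so the Python side can pass `None` (see `functional.linear` below, which forwards `bias._underlying` only when a bias tensor is given). The computation is the usual y = xA^T + b; a NumPy sketch of those semantics, assuming the torch-style (out_features, in_features) weight layout implied by that docstring:

```python
import numpy as np

def linear_ref(x: np.ndarray, weight: np.ndarray, bias=None) -> np.ndarray:
    """y = x @ weight.T (+ bias), with weight shaped (out_features, in_features)."""
    y = x @ weight.T
    if bias is not None:
        y = y + bias
    return y

x = np.ones((2, 4), dtype=np.float32)
w = np.ones((3, 4), dtype=np.float32)
assert linear_ref(x, w).shape == (2, 3)           # bias omitted, like bias=None
assert (linear_ref(x, w, np.ones(3)) == 5).all()  # 4 + 1
```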
include/infinicore/ops/random_sample.hpp

Lines changed: 22 additions & 0 deletions

```cpp
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include "infinicore/tensor.hpp"

namespace infinicore::op {

class RandomSample {
public:
    using schema = void (*)(Tensor, Tensor, float, float, int, float);
    static void execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature);
// In-place API
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);

} // namespace infinicore::op
```
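The schema bundles the four sampling knobs into one call: `temperature` rescales the logits, `topk` and `topp` restrict the candidate set, and `random_val` is a pre-drawn uniform number in [0, 1) that makes the kernel deterministic for a given input. The exact filtering order inside the kernel isn't shown in this commit; a plausible NumPy reference under the usual top-k, then nucleus (top-p), then inverse-CDF convention:

```python
import numpy as np

def random_sample_ref(logits, random_val, topp, topk, temperature):
    """Plausible reference: temperature softmax, top-k / top-p filtering,
    then index selection by inverse CDF using the pre-drawn `random_val`."""
    probs = np.exp((logits - logits.max()) / temperature)  # softmax(logits / T)
    probs /= probs.sum()
    order = np.argsort(-probs)                 # candidates, most probable first
    keep = order[:topk] if topk > 0 else order
    cum = np.cumsum(probs[keep])
    keep = keep[: int(np.searchsorted(cum, topp)) + 1]     # nucleus cutoff
    p = np.cumsum(probs[keep] / probs[keep].sum())         # renormalized CDF
    j = min(int(np.searchsorted(p, random_val)), len(keep) - 1)
    return int(keep[j])

logits = np.array([2.0, 1.0, 0.5, -1.0], dtype=np.float32)
idx = random_sample_ref(logits, random_val=0.3, topp=0.9, topk=3, temperature=1.0)
assert 0 <= idx < logits.size
```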

python/infinicore/nn/__init__.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -1,3 +1,8 @@
 from infinicore.nn import (
     functional as functional,
 )
+from infinicore.nn import (
+    modules as modules,
+)
+from infinicore.nn.functional import *  # noqa: F403
+from infinicore.nn.modules import *  # noqa: F403
```
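With these re-exports in place, both the namespaced and the flat spellings resolve to the same objects; a quick sketch of the resulting import surface (names taken from the files in this commit):

```python
import infinicore.nn as nn
import infinicore.nn.functional as F

assert nn.functional is F          # same module object, two spellings
layers = nn.ModuleList()           # re-exported from infinicore.nn.modules
print(F.__all__)                   # ['causal_softmax', 'random_sample', ...]
```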

python/infinicore/nn/functional.py

Lines changed: 88 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor
 
-__all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"]
+__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]
 
 
 def causal_softmax(input: Tensor, out=None) -> Tensor:
@@ -105,6 +105,93 @@ def scaled_dot_product_attention(
         key._underlying,
         value._underlying,
         scale,
     )
 
     return out
+
+
+def embedding(input: Tensor, weight: Tensor, *, out=None) -> Tensor:
+    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size."""
+    if out is None:
+        return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
+
+    _infinicore.embedding_(out._underlying, input._underlying, weight._underlying)
+    return out
+
+
+def rope(
+    x: Tensor,
+    pos_ids: Tensor,
+    sin_table: Tensor,
+    cos_table: Tensor,
+    algo: _infinicore.Algo = _infinicore.Algo.GPT_NEOX,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Rotary Position Embedding (RoPE)."""
+    if out is None:
+        return Tensor(
+            _infinicore.rope(
+                x._underlying,
+                pos_ids._underlying,
+                sin_table._underlying,
+                cos_table._underlying,
+                algo,
+            )
+        )
+
+    _infinicore.rope_(
+        out._underlying,
+        x._underlying,
+        pos_ids._underlying,
+        sin_table._underlying,
+        cos_table._underlying,
+        algo,
+    )
+    return out
+
+
+def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
+    r"""Apply a linear transformation to the incoming data: y = xA^T + b."""
+    if out is None:
+        return Tensor(
+            _infinicore.linear(
+                input._underlying,
+                weight._underlying,
+                None if bias is None else bias._underlying,
+            )
+        )
+
+    _infinicore.linear_(
+        out._underlying,
+        input._underlying,
+        weight._underlying,
+        None if bias is None else bias._underlying,
+    )
+    return out
+
+
+def random_sample(
+    logits: Tensor,
+    random_val: float,
+    topp: float,
+    topk: int,
+    temperature: float,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Sample an index from logits with nucleus/top-k filtering."""
+    if out is None:
+        return Tensor(
+            _infinicore.random_sample(
+                logits._underlying,
+                random_val,
+                topp,
+                topk,
+                temperature,
+            )
+        )
+
+    _infinicore.random_sample_(
+        out._underlying,
+        logits._underlying,
+        random_val,
+        topp,
+        topk,
+        temperature,
+    )
+    return out
```
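`rope` defaults to `_infinicore.Algo.GPT_NEOX`, the variant that rotates the two halves of the head dimension rather than interleaved pairs. A NumPy sketch of that rotation (the (seq, dim) layout and half-width table shapes are assumptions for illustration, not taken from this commit):

```python
import numpy as np

def rope_neox_ref(x, pos_ids, sin_table, cos_table):
    """GPT-NeoX-style RoPE: rotate the two halves of the last dimension
    by angles looked up per position. Assumed shapes: x (seq, dim),
    sin/cos tables (max_pos, dim // 2)."""
    sin, cos = sin_table[pos_ids], cos_table[pos_ids]  # (seq, dim // 2)
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

dim, seq = 8, 4
inv_freq = 1.0 / (10000.0 ** (np.arange(dim // 2) / (dim // 2)))
angles = np.outer(np.arange(16), inv_freq)             # (max_pos, dim // 2)
x = np.random.randn(seq, dim).astype(np.float32)
y = rope_neox_ref(x, np.arange(seq), np.sin(angles), np.cos(angles))
assert y.shape == x.shape
```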
python/infinicore/nn/modules/__init__.py

Lines changed: 3 additions & 0 deletions

```python
from .container import ModuleList
from .module import Module
from .parameter import Parameter
```
python/infinicore/nn/modules/container.py

Lines changed: 191 additions & 0 deletions

```python
# ============================================
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.

import operator
from collections import OrderedDict
from itertools import chain
from typing import Iterator, List, Optional, Sequence, TypeVar, Union

from .module import Module

# Type variable for the contained modules (anything derived from Module)
ModuleType = TypeVar("ModuleType", bound=Module)


class InfiniCoreModuleList(Module):
    r"""Holds submodules in a list.

    InfiniCoreModuleList can be indexed like a regular Python list, but the
    modules it contains are properly registered and will be visible to all
    InfiniCoreModule methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        >>> class MyModel(InfiniCoreModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linears = InfiniCoreModuleList([
        ...             torch.nn.Linear(10, 10) for i in range(10)
        ...         ])
        ...
        ...     def forward(self, x):
        ...         # ModuleList can act as an iterable, or be indexed using ints
        ...         for i, l in enumerate(self.linears):
        ...             x = self.linears[i // 2](x) + l(x)
        ...         return x
    """

    def __init__(self, modules: Optional[Sequence[ModuleType]] = None):
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[ModuleType, "InfiniCoreModuleList"]:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: ModuleType) -> None:
        idx = self._get_abs_string_index(idx)
        # Use add_module to register the module
        self.add_module(idx, module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            indices_to_delete = list(range(len(self._modules)))[idx]
            for k in indices_to_delete:
                if str(k) in self._modules:
                    del self._modules[str(k)]
        else:
            idx_str = self._get_abs_string_index(idx)
            if idx_str in self._modules:
                del self._modules[idx_str]

        # To preserve numbering, self._modules is reconstructed after deletion
        if len(self._modules) > 0:
            str_indices = [str(i) for i in range(len(self._modules))]
            self._modules = OrderedDict(zip(str_indices, self._modules.values()))

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[ModuleType]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
        return self.extend(modules)

    def __add__(
        self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"]
    ) -> "InfiniCoreModuleList":
        r"""Return a new InfiniCoreModuleList by concatenating with another iterable.

        Args:
            other (iterable): iterable of modules to concatenate
        """
        if not isinstance(other, (list, tuple, InfiniCoreModuleList)):
            raise TypeError(
                f"InfiniCoreModuleList can only be concatenated with list, tuple, "
                f"or InfiniCoreModuleList, got {type(other).__name__}"
            )

        combined = InfiniCoreModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def append(self, module: ModuleType) -> "InfiniCoreModuleList":
        r"""Append a given module to the end of the list.

        Args:
            module (InfiniCoreModule): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, (list, tuple)):
            try:
                modules = list(modules)
            except TypeError:
                raise TypeError(
                    f"InfiniCoreModuleList.extend should be called with an "
                    f"iterable, but got {type(modules).__name__}"
                )

        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    def insert(self, index: int, module: ModuleType) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (InfiniCoreModule): module to insert
        """
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def pop(self, idx: int = -1) -> ModuleType:
        r"""Remove and return a module at the given index.

        Args:
            idx (int): index of the module to pop. Default: -1 (last module)

        Returns:
            Module: the module that was removed
        """
        idx_str = self._get_abs_string_index(idx)
        module = self._modules[idx_str]
        # Use __delitem__ to ensure proper cleanup and renumbering
        self.__delitem__(int(idx_str))
        return module

    def __repr__(self) -> str:
        """Return a string representation of the ModuleList."""
        if len(self) == 0:
            return self.__class__.__name__ + "()"

        lines = []
        for i, module in enumerate(self):
            lines.append(f"({i}): {repr(module)}")

        main_str = self.__class__.__name__ + "(\n  "
        main_str += "\n  ".join(lines) + "\n)"
        return main_str

    def __dir__(self) -> List[str]:
        """Return a list of attribute names, excluding numeric keys."""
        keys = super().__dir__()
        # Filter out numeric keys to avoid cluttering dir() output
        keys = [key for key in keys if not key.isdigit()]
        return keys


ModuleList = InfiniCoreModuleList
```
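Because every mutation goes through `add_module` or `_modules`, the container keeps its keys as a dense run of string indices ("0", "1", ...), which is what makes index-based traversal work after inserts and deletions. A small usage sketch (`Block` is a hypothetical stand-in; it only assumes the torch-like `Module` API that container.py itself relies on):

```python
from infinicore.nn import Module, ModuleList

class Block(Module):
    """Hypothetical placeholder layer; only Module registration matters here."""
    def __init__(self):
        super().__init__()

layers = ModuleList(Block() for _ in range(3))
layers.append(Block())       # registered under key "3"
layers.insert(1, Block())    # keys "1".."3" shift up by one
last = layers.pop()          # removes the last entry and renumbers the rest
assert len(layers) == 4 and list(layers._modules) == ["0", "1", "2", "3"]
```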
