
Commit e513f17
[AINode] Integrate python code standardization (#15593)
1 parent b3cb21f, commit e513f17
36 files changed: +2448 and -1155 lines
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+name: AINode Code Style Check
+
+on:
+  push:
+    branches:
+      - master
+      - "rc/*"
+    paths:
+      - 'iotdb-core/ainode/**'
+  pull_request:
+    branches:
+      - master
+      - "rc/*"
+    paths:
+      - 'iotdb-core/ainode/**'
+  # allow manually run the action:
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+env:
+  MAVEN_OPTS: -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false -Dmaven.wagon.http.retryHandler.class=standard -Dmaven.wagon.http.retryHandler.count=3
+  MAVEN_ARGS: --batch-mode --no-transfer-progress
+
+jobs:
+  check-style:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install dependencies
+        run: |
+          pip3 install black==25.1.0 isort==6.0.1
+      - name: Check code formatting (Black)
+        run: |
+          cd iotdb-core/ainode
+          black --check .
+        continue-on-error: false
+
+      - name: Check import order (Isort)
+        run: |
+          cd iotdb-core/ainode
+          isort --check-only --profile black .
+        continue-on-error: false
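The two checks above gate every AINode change on Black formatting and isort's Black-profile import ordering. Contributors can reproduce the gate locally with pip3 install black==25.1.0 isort==6.0.1, then black --check . and isort --check-only --profile black . from iotdb-core/ainode; dropping --check makes the tools rewrite the files in place. Below is a minimal sketch of the conventions being enforced, using a hypothetical module rather than code from this commit: standard-library imports first, a blank line, then third-party imports, with expressions wrapped to Black's default 88-column limit.

import math

import torch


def scaled_positions(n_tokens: int, scale: float = 1.0) -> torch.Tensor:
    # isort (profile "black") keeps the stdlib "import math" above "import torch";
    # Black collapses hand-wrapped calls like the ones reformatted in this commit.
    return torch.arange(n_tokens, dtype=torch.float32) * scale * math.pi


if __name__ == "__main__":
    print(scaled_positions(4).tolist())  # [0.0, 3.14..., 6.28..., 9.42...]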

iotdb-core/ainode/ainode/TimerXL/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#
+#

iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py

Lines changed: 28 additions & 15 deletions
@@ -17,6 +17,7 @@
 #
 import abc
 import math
+
 import torch
 from einops import rearrange
 from torch import nn
@@ -41,22 +42,23 @@ def __init__(self, dim: int, num_heads: int):
 
     def forward(self, query_id, kv_id):
         ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
-        weight = rearrange(
-            self.emb.weight, "two num_heads -> two num_heads 1 1")
+        weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
         bias = ~ind * weight[:1] + ind * weight[1:]
         return bias
 
 
-def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
     relative_buckets = 0
     if bidirectional:
         num_buckets //= 2
-        relative_buckets += (relative_position >
-                             0).to(torch.long) * num_buckets
+        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
         relative_position = torch.abs(relative_position)
     else:
-        relative_position = - \
-            torch.min(relative_position, torch.zeros_like(relative_position))
+        relative_position = -torch.min(
+            relative_position, torch.zeros_like(relative_position)
+        )
 
     max_exact = num_buckets // 2
     is_small = relative_position < max_exact
@@ -66,12 +68,13 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         * (num_buckets - max_exact)
     ).to(torch.long)
     relative_position_if_large = torch.min(
-        relative_position_if_large, torch.full_like(
-            relative_position_if_large, num_buckets - 1)
+        relative_position_if_large,
+        torch.full_like(relative_position_if_large, num_buckets - 1),
     )
 
-    relative_buckets += torch.where(is_small,
-                                    relative_position, relative_position_if_large)
+    relative_buckets += torch.where(
+        is_small, relative_position, relative_position_if_large
+    )
     return relative_buckets
 
 
@@ -83,11 +86,21 @@ def __init__(self, dim: int, num_heads: int):
         self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)
 
     def forward(self, n_vars, n_tokens):
-        context_position = torch.arange(n_tokens, dtype=torch.long,)[:, None]
-        memory_position = torch.arange(n_tokens, dtype=torch.long, )[None, :]
+        context_position = torch.arange(
+            n_tokens,
+            dtype=torch.long,
+        )[:, None]
+        memory_position = torch.arange(
+            n_tokens,
+            dtype=torch.long,
+        )[None, :]
         relative_position = memory_position - context_position
-        bucket = _relative_position_bucket(relative_position=relative_position, bidirectional=False,
-                                           num_buckets=self.num_buckets, max_distance=self.max_distance).to(self.relative_attention_bias.weight.device)
+        bucket = _relative_position_bucket(
+            relative_position=relative_position,
+            bidirectional=False,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        ).to(self.relative_attention_bias.weight.device)
         bias = self.relative_attention_bias(bucket).squeeze(-1)
         bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
         mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
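All of the edits in this file are formatting-only. A self-contained toy check (random and hand-picked tensors, not values from the model) that the re-wrapped expressions above still compute what the original hand-wrapped lines did:

import torch
from einops import rearrange

# Stand-in for self.emb.weight with num_heads = 4 (hypothetical values).
emb_weight = torch.randn(2, 4)
weight = rearrange(emb_weight, "two num_heads -> two num_heads 1 1")
assert weight.shape == (2, 4, 1, 1)

# The bucketing expressions touched above, exercised on a toy tensor
# (num_buckets // 2 = 16 matches the function's defaults).
relative_position = torch.tensor([-3, -1, 0, 2, 5])
forward_buckets = (relative_position > 0).to(torch.long) * 16
clamped = -torch.min(relative_position, torch.zeros_like(relative_position))
print(forward_buckets.tolist())  # [0, 0, 0, 16, 16]
print(clamped.tolist())  # [3, 1, 0, 0, 0]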

iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py

Lines changed: 9 additions & 5 deletions
@@ -16,8 +16,9 @@
 # under the License.
 #
 import abc
-import torch
 from functools import cached_property
+
+import torch
 from einops import einsum, rearrange, repeat
 from torch import nn
 
@@ -33,7 +34,9 @@ def forward(self, x, seq_id): ...
 
 
 class RotaryProjection(Projection):
-    def __init__(self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000):
+    def __init__(
+        self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000
+    ):
         super().__init__(proj_width, num_heads)
         assert (
             self.proj_width % 2 == 0
@@ -57,8 +60,7 @@ def _init_freq(self, max_len: int):
         position = torch.arange(
             max_len, device=self.theta.device, dtype=self.theta.dtype
         )
-        m_theta = einsum(position, self.theta,
-                         "length, width -> length width")
+        m_theta = einsum(position, self.theta, "length, width -> length width")
         m_theta = repeat(m_theta, "length width -> length (width 2)")
         self.register_buffer("cos", torch.cos(m_theta), persistent=False)
         self.register_buffer("sin", torch.sin(m_theta), persistent=False)
@@ -76,7 +78,9 @@ def forward(self, x, seq_id):
 
 
 class QueryKeyProjection(nn.Module):
-    def __init__(self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None):
+    def __init__(
+        self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None
+    ):
         super().__init__()
         if partial_factor is not None:
             assert (