diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py deleted file mode 100644 index 3c5ad7600328..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import abc -import math - -import torch -from einops import rearrange -from torch import nn - - -class AttentionBias(nn.Module, abc.ABC): - def __init__(self, dim: int, num_heads: int): - super().__init__() - assert num_heads > 0 and dim % num_heads == 0 - - self.num_heads = num_heads - self.head_dim = dim // num_heads - - @abc.abstractmethod - def forward(self, query_id, kv_id): ... - - -class BinaryAttentionBias(AttentionBias): - def __init__(self, dim: int, num_heads: int): - super().__init__(dim, num_heads) - self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads) - - def forward(self, query_id, kv_id): - ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2)) - weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1") - bias = ~ind * weight[:1] + ind * weight[1:] - return bias - - -def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 -): - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets - - -class T5AttentionBias(AttentionBias): - def __init__(self, dim: int, num_heads: int): - super().__init__(dim, num_heads) - self.num_buckets = 32 - self.max_distance = 32 - self.relative_attention_bias = nn.Embedding(self.num_buckets, 1) - - def forward(self, n_vars, n_tokens): - context_position = torch.arange( - n_tokens, - dtype=torch.long, - )[:, None] - memory_position = torch.arange( - n_tokens, - dtype=torch.long, - )[None, :] - relative_position = memory_position - context_position - bucket = _relative_position_bucket( - relative_position=relative_position, - bidirectional=False, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - ).to(self.relative_attention_bias.weight.device) - bias = 
self.relative_attention_bias(bucket).squeeze(-1) - bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1]) - mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device) - final_bias = torch.kron(mask1, bias) - return final_bias diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py deleted file mode 100644 index 18e2b29c3d6e..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py +++ /dev/null @@ -1,127 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import abc -from functools import cached_property - -import torch -from einops import einsum, rearrange, repeat -from torch import nn - - -class Projection(nn.Module, abc.ABC): - def __init__(self, proj_width: int, num_heads: int, **kwargs): - super().__init__() - self.proj_width = proj_width - self.num_heads = num_heads - - @abc.abstractmethod - def forward(self, x, seq_id): ... - - -class RotaryProjection(Projection): - def __init__( - self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000 - ): - super().__init__(proj_width, num_heads) - assert ( - self.proj_width % 2 == 0 - ), f"proj_width must be even, got {self.proj_width}" - self.register_buffer( - "theta", - 1.0 - / torch.pow( - base, - torch.arange(0, self.proj_width, 2, dtype=torch.float) - / self.proj_width, - ), - persistent=False, - ) - self.register_buffer("cos", None, persistent=False) - self.register_buffer("sin", None, persistent=False) - self._init_freq(max_len=max_len) - - def _init_freq(self, max_len: int): - if self.cos is None or self.cos.size(-2) < max_len: - position = torch.arange( - max_len, device=self.theta.device, dtype=self.theta.dtype - ) - m_theta = einsum(position, self.theta, "length, width -> length width") - m_theta = repeat(m_theta, "length width -> length (width 2)") - self.register_buffer("cos", torch.cos(m_theta), persistent=False) - self.register_buffer("sin", torch.sin(m_theta), persistent=False) - - @staticmethod - def _rotate(x): - x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2) - return rearrange([-x2, x1], "r ... dim -> ... 
(dim r)", r=2) # noqa - - def forward(self, x, seq_id): - self._init_freq(max_len=seq_id.max() + 1) - rot_cos = self.cos[seq_id] - rot_sin = self.sin[seq_id] - return rot_cos * x + rot_sin * self._rotate(x) - - -class QueryKeyProjection(nn.Module): - def __init__( - self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None - ): - super().__init__() - if partial_factor is not None: - assert ( - 0.0 <= partial_factor[0] < partial_factor[1] <= 1.0 - ), f"got {partial_factor[0]}, {partial_factor[1]}" - assert num_heads > 0 and dim % num_heads == 0 - - self.head_dim = dim // num_heads - self.partial_factor = partial_factor - self.query_proj = proj_layer( - proj_width=self.proj_width, - num_heads=num_heads, - **(kwargs or {}), - ) - self.key_proj = self.query_proj - - @cached_property - def proj_width(self) -> int: - if self.partial_factor is None: - return self.head_dim - return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0])) - - @cached_property - def split_sizes(self): - if self.partial_factor is None: - return 0, self.head_dim, 0 - return ( - int(self.partial_factor[0] * self.head_dim), - self.proj_width, - int((1.0 - self.partial_factor[1]) * self.head_dim), - ) - - def forward(self, query, key, query_id, kv_id): - if self.partial_factor is not None: - queries = list(query.split(self.split_sizes, dim=-1)) - keys = list(key.split(self.split_sizes, dim=-1)) - queries[1] = self.query_proj(queries[1], seq_id=query_id) - keys[1] = self.key_proj(keys[1], seq_id=kv_id) - query = torch.cat(queries, dim=-1) - key = torch.cat(keys, dim=-1) - else: - query = self.query_proj(query, seq_id=query_id) - key = self.key_proj(key, seq_id=kv_id) - return query, key diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py b/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py deleted file mode 100644 index 8c3cf570bafe..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py +++ /dev/null @@ -1,290 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import math - -import torch -import torch.nn as nn -from torch.jit import is_scripting - -from ainode.TimerXL.models.configuration_timer import TimerxlConfig - - -class PositionalEmbedding(nn.Module): - def __init__(self, d_model, max_len=6500): - super(PositionalEmbedding, self).__init__() - # Compute the positional encodings once in log space. 
- pe = torch.zeros(max_len, d_model).float() - pe.require_grad = False - - position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = ( - torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) - ).exp() - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) - - def forward(self, x): - return self.pe[:, : x.size(1)] - - -class TokenEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(TokenEmbedding, self).__init__() - padding = 1 if torch.__version__ >= "1.5.0" else 2 - self.tokenConv = nn.Conv1d( - in_channels=c_in, - out_channels=d_model, - kernel_size=3, - padding=padding, - padding_mode="circular", - bias=False, - ) - for m in self.modules(): - if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_( - m.weight, mode="fan_in", nonlinearity="leaky_relu" - ) - - def forward(self, x): - x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) - return x - - -class FixedEmbedding(nn.Module): - def __init__(self, c_in, d_model): - super(FixedEmbedding, self).__init__() - - w = torch.zeros(c_in, d_model).float() - w.require_grad = False - - position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = ( - torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) - ).exp() - - w[:, 0::2] = torch.sin(position * div_term) - w[:, 1::2] = torch.cos(position * div_term) - - self.emb = nn.Embedding(c_in, d_model) - self.emb.weight = nn.Parameter(w, requires_grad=False) - - def forward(self, x): - return self.emb(x).detach() - - -class TemporalEmbedding(nn.Module): - def __init__(self, d_model, embed_type="fixed", freq="h"): - super(TemporalEmbedding, self).__init__() - - minute_size = 4 - hour_size = 24 - weekday_size = 7 - day_size = 32 - month_size = 13 - - Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding - if freq == "t": - self.minute_embed = Embed(minute_size, d_model) - self.hour_embed = Embed(hour_size, d_model) - self.weekday_embed = Embed(weekday_size, d_model) - self.day_embed = Embed(day_size, d_model) - self.month_embed = Embed(month_size, d_model) - - def forward(self, x): - x = x.long() - minute_x = ( - self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 - ) - hour_x = self.hour_embed(x[:, :, 3]) - weekday_x = self.weekday_embed(x[:, :, 2]) - day_x = self.day_embed(x[:, :, 1]) - month_x = self.month_embed(x[:, :, 0]) - - return hour_x + weekday_x + day_x + month_x + minute_x - - -class TimeFeatureEmbedding(nn.Module): - def __init__(self, d_model, embed_type="timeF", freq="h"): - super(TimeFeatureEmbedding, self).__init__() - - freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} - d_inp = freq_map[freq] - self.embed = nn.Linear(d_inp, d_model, bias=False) - - def forward(self, x): - return self.embed(x) - - -class DataEmbedding(nn.Module): - def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): - super(DataEmbedding, self).__init__() - - self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) - self.position_embedding = PositionalEmbedding(d_model=d_model) - self.temporal_embedding = ( - TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) - if embed_type != "timeF" - else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) - ) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x, x_mark): - if x_mark is None: - x = self.value_embedding(x) + self.position_embedding(x) - else: - x = ( - 
self.value_embedding(x) - + self.temporal_embedding(x_mark) - + self.position_embedding(x) - ) - return self.dropout(x) - - -class DataEmbedding_inverted(nn.Module): - def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): - super(DataEmbedding_inverted, self).__init__() - self.value_embedding = nn.Linear(c_in, d_model) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x, x_mark): - x = x.permute(0, 2, 1) - # x: [Batch Variate Time] - if x_mark is None: - x = self.value_embedding(x) - else: - x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) - # x: [Batch Variate d_model] - return self.dropout(x) - - -class DataEmbedding_wo_pos(nn.Module): - def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): - super(DataEmbedding_wo_pos, self).__init__() - - self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) - self.position_embedding = PositionalEmbedding(d_model=d_model) - self.temporal_embedding = ( - TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) - if embed_type != "timeF" - else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) - ) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x, x_mark): - if x_mark is None: - x = self.value_embedding(x) - else: - x = self.value_embedding(x) + self.temporal_embedding(x_mark) - return self.dropout(x) - - -class PatchEmbedding(nn.Module): - def __init__(self, d_model, patch_len, stride, padding, dropout): - super(PatchEmbedding, self).__init__() - # Patching - self.patch_len = patch_len - self.stride = stride - self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) - - # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space - self.value_embedding = nn.Linear(patch_len, d_model, bias=False) - - # Positional embedding - self.position_embedding = PositionalEmbedding(d_model) - - # Residual dropout - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - # do patching - n_vars = x.shape[1] - x = self.padding_patch_layer(x) - x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) - x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) - # Input encoding - x = self.value_embedding(x) + self.position_embedding(x) - return self.dropout(x), n_vars - - -class TimerPatchEmbedding(nn.Module): - def __init__(self, config: TimerxlConfig): - super().__init__() - self.input_token_len = config.input_token_len - self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False) - - def forward(self, hidden_state: torch.Tensor): - hidden_state = hidden_state.unfold( - dimension=-1, size=self.input_token_len, step=self.input_token_len - ) - return self.emb(hidden_state) - - -class TimeMoeRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.max_seq_len_cached: int = 0 - inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) - / self.dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache( - self, seq_len: int, device: torch.device, dtype: torch.dtype - ): - self.max_seq_len_cached = int(seq_len) - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64 - ).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - if not is_scripting(): - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - else: - self.cos_cached = emb.cos().to(dtype) - self.sin_cached = emb.sin().to(dtype) - - def forward(self, x, seq_len: int = 0): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py b/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py deleted file mode 100644 index 4a2fb0d27e09..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py +++ /dev/null @@ -1,207 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from math import sqrt -from typing import Any, Optional, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import repeat - -from ainode.core.util.huggingface_cache import Cache, DynamicCache -from ainode.core.util.masking import ( - TimerCovariateMask, - TimerMultivariateMask, - TriangularCausalMask, -) -from ainode.TimerXL.layers.Attn_Bias import BinaryAttentionBias -from ainode.TimerXL.layers.Attn_Projection import QueryKeyProjection, RotaryProjection -from ainode.TimerXL.layers.Embed import TimeMoeRotaryEmbedding -from ainode.TimerXL.models.configuration_timer import TimerxlConfig - - -def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class FullAttention(nn.Module): - def __init__( - self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False - ): - super(FullAttention, self).__init__() - self.scale = scale - self.mask_flag = mask_flag - self.output_attention = output_attention - self.dropout = nn.Dropout(attention_dropout) - - def forward( - self, - queries, - keys, - values, - attn_mask, - n_vars=None, - n_tokens=None, - tau=None, - delta=None, - ): - B, L, H, E = queries.shape - _, S, _, D = values.shape - scale = self.scale or 1.0 / sqrt(E) - - scores = torch.einsum("blhe,bshe->bhls", queries, keys) - - if self.mask_flag: - if attn_mask is None: - attn_mask = TriangularCausalMask(B, L, device=queries.device) - - scores.masked_fill_(attn_mask.mask, -np.inf) - - A = self.dropout(torch.softmax(scale * scores, dim=-1)) - V = torch.einsum("bhls,bshd->blhd", A, values) - - if self.output_attention: - return V.contiguous(), A - else: - return V.contiguous(), None - - -class TimerAttention(nn.Module): - def __init__(self, config: TimerxlConfig, layer_idx: Optional[int] = None): - super().__init__() - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.attention_dropout = config.attention_dropout - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_emb = TimeMoeRotaryEmbedding( - self.head_dim, max_position_embeddings=config.max_position_embeddings - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional["Cache"] = None, - ) -> Tuple[torch.Tensor, Optional["Cache"]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - - 
kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx - ) - - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attention_mask, - dropout_p=self.attention_dropout, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output, past_key_value - - -class AttentionLayer(nn.Module): - def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): - super(AttentionLayer, self).__init__() - - d_keys = d_keys or (d_model // n_heads) - d_values = d_values or (d_model // n_heads) - - self.inner_attention = attention - self.query_projection = nn.Linear(d_model, d_keys * n_heads) - self.key_projection = nn.Linear(d_model, d_keys * n_heads) - self.value_projection = nn.Linear(d_model, d_values * n_heads) - self.out_projection = nn.Linear(d_values * n_heads, d_model) - self.n_heads = n_heads - - def forward( - self, - queries, - keys, - values, - attn_mask, - n_vars=None, - n_tokens=None, - tau=None, - delta=None, - ): - B, L, _ = queries.shape - _, S, _ = keys.shape - H = self.n_heads - - queries = self.query_projection(queries).view(B, L, H, -1) - keys = self.key_projection(keys).view(B, S, H, -1) - values = self.value_projection(values).view(B, S, H, -1) - - out, attn = self.inner_attention( - queries, - keys, - values, - attn_mask, - n_vars=n_vars, - n_tokens=n_tokens, - tau=tau, - delta=delta, - ) - out = out.view(B, L, -1) - - return self.out_projection(out), attn diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py b/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py deleted file mode 100644 index d5bad30ea055..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py +++ /dev/null @@ -1,329 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ainode.core.util.activation import ACT2FN -from ainode.core.util.huggingface_cache import Cache, DynamicCache -from ainode.TimerXL.layers.SelfAttention_Family import TimerAttention -from ainode.TimerXL.models.configuration_timer import TimerxlConfig - - -class EncoderLayer(nn.Module): - def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): - super(EncoderLayer, self).__init__() - d_ff = d_ff or 4 * d_model - self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - self.activation = F.relu if activation == "relu" else F.gelu - - def forward(self, x, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) - x = x + self.dropout(new_x) - - y = x = self.norm1(x) - y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) - y = self.dropout(self.conv2(y).transpose(-1, 1)) - - return self.norm2(x + y), attn - - -class DecoderLayer(nn.Module): - def __init__( - self, - self_attention, - cross_attention, - d_model, - d_ff=None, - dropout=0.1, - activation="relu", - ): - super(DecoderLayer, self).__init__() - d_ff = d_ff or 4 * d_model - self.self_attention = self_attention - self.cross_attention = cross_attention - self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - self.activation = F.relu if activation == "relu" else F.gelu - - def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - x = x + self.dropout( - self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] - ) - x = self.norm1(x) - - x = x + self.dropout( - self.cross_attention( - x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta - )[0] - ) - - y = x = self.norm2(x) - y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) - y = self.dropout(self.conv2(y).transpose(-1, 1)) - - return self.norm3(x + y) - - -class DecoderOnlyLayer(nn.Module): - def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): - super(DecoderOnlyLayer, self).__init__() - d_ff = d_ff or 4 * d_model - self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - self.activation = F.relu if activation == "relu" else F.gelu - - def forward(self, x, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) - x = x + self.dropout(new_x) - - y = x = self.norm1(x) - y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) - y = self.dropout(self.conv2(y).transpose(-1, 1)) - - return self.norm2(x + y), attn - - -class TimerLayer(nn.Module): - def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): - super(TimerLayer, self).__init__() - d_ff = d_ff or 4 
* d_model - self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - self.activation = F.relu if activation == "relu" else F.gelu - - def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention( - x, - x, - x, - n_vars=n_vars, - n_tokens=n_tokens, - attn_mask=attn_mask, - tau=tau, - delta=delta, - ) - x = x + self.dropout(new_x) - - y = x = self.norm1(x) - y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) - y = self.dropout(self.conv2(y).transpose(-1, 1)) - - return self.norm2(x + y), attn - - -class Encoder(nn.Module): - def __init__(self, attn_layers, conv_layers=None, norm_layer=None): - super(Encoder, self).__init__() - self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) - self.norm = norm_layer - - def forward(self, x, attn_mask=None, tau=None, delta=None): - # x [B, L, D] - attns = [] - if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate( - zip(self.attn_layers, self.conv_layers) - ): - delta = delta if i == 0 else None - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - x = conv_layer(x) - attns.append(attn) - x, attn = self.attn_layers[-1](x, tau=tau, delta=None) - attns.append(attn) - else: - for attn_layer in self.attn_layers: - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - attns.append(attn) - - if self.norm is not None: - x = self.norm(x) - - return x, attns - - -class Decoder(nn.Module): - def __init__(self, layers, norm_layer=None, projection=None): - super(Decoder, self).__init__() - self.layers = nn.ModuleList(layers) - self.norm = norm_layer - self.projection = projection - - def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - for layer in self.layers: - x = layer( - x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta - ) - - if self.norm is not None: - x = self.norm(x) - - if self.projection is not None: - x = self.projection(x) - return x - - -class DecoderOnly(nn.Module): - def __init__(self, attn_layers, conv_layers=None, norm_layer=None): - super(DecoderOnly, self).__init__() - self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) - self.norm = norm_layer - - def forward(self, x, attn_mask=None, tau=None, delta=None): - # x [B, L, D] - attns = [] - if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate( - zip(self.attn_layers, self.conv_layers) - ): - delta = delta if i == 0 else None - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - x = conv_layer(x) - attns.append(attn) - x, attn = self.attn_layers[-1](x, tau=tau, delta=None) - attns.append(attn) - else: - for attn_layer in self.attn_layers: - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - attns.append(attn) - - if self.norm is not None: - x = self.norm(x) - - return x, attns - - -class TimerBlock(nn.Module): - def __init__(self, attn_layers, conv_layers=None, norm_layer=None): - super(TimerBlock, self).__init__() - self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) - 
self.norm = norm_layer - - def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None): - # x [B, L, D] - attns = [] - if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate( - zip(self.attn_layers, self.conv_layers) - ): - delta = delta if i == 0 else None - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - x = conv_layer(x) - attns.append(attn) - x, attn = self.attn_layers[-1](x, n_vars, n_tokens, tau=tau, delta=None) - attns.append(attn) - else: - for attn_layer in self.attn_layers: - x, attn = attn_layer( - x, n_vars, n_tokens, attn_mask=attn_mask, tau=tau, delta=delta - ) - attns.append(attn) - - if self.norm is not None: - x = self.norm(x) - - return x, attns - - -class TimerMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) - - -class TimerDecoderLayer(nn.Module): - def __init__(self, config: TimerxlConfig, layer_idx: int): - super().__init__() - self.self_attn = TimerAttention(config, layer_idx) - - self.ffn_layer = TimerMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.norm1 = torch.nn.LayerNorm(config.hidden_size) - self.norm2 = torch.nn.LayerNorm(config.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: bool = False, - ) -> Tuple[torch.FloatTensor, Optional[Cache]]: - residual = hidden_states - - # Self Attention - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - ) - - hidden_states = residual + hidden_states - hidden_states = self.norm1(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.ffn_layer(hidden_states) - hidden_states = residual + hidden_states - hidden_states = self.norm2(hidden_states) - - if not use_cache: - present_key_value = None - return hidden_states, present_key_value diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py b/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py deleted file mode 100644 index 2a1e720805f2..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# diff --git a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py b/iotdb-core/ainode/ainode/TimerXL/models/__init__.py deleted file mode 100644 index 2a1e720805f2..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py deleted file mode 100644 index b3962a052a52..000000000000 --- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py +++ /dev/null @@ -1,446 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -import os -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple - -import torch -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors -from torch import nn - -from ainode.core.log import Logger -from ainode.core.util.huggingface_cache import Cache, DynamicCache -from ainode.core.util.masking import prepare_4d_causal_attention_mask -from ainode.TimerXL.layers.Embed import TimerPatchEmbedding -from ainode.TimerXL.layers.Transformer_EncDec import TimerDecoderLayer -from ainode.TimerXL.models.configuration_timer import TimerxlConfig - -logger = Logger() - - -@dataclass -class Output: - outputs: torch.Tensor - past_key_values: Optional[Any] = None - - -class TimerModel(nn.Module): - def __init__(self, config: TimerxlConfig): - super().__init__() - self.config = config - self.embed_layer = TimerPatchEmbedding(config) - self.layers = nn.ModuleList( - [ - TimerDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = torch.nn.LayerNorm(config.hidden_size) - self.gradient_checkpointing = False - - def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: bool = None, - ): - # input_ids is the input of time series, its shape is [batch_size, seq_len] - - if input_ids is not None: - batch_size, seq_length = input_ids.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - inputs_embeds = self.embed_layer(input_ids) - - seq_length = inputs_embeds.shape[1] - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - # 4d mask is passed through the layers - attention_mask = prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = None - - for decoder_layer in self.layers: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[1] - - hidden_states = self.norm(hidden_states) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache - else next_decoder_cache - ) - - return Output(outputs=hidden_states, past_key_values=next_cache) - - -class TimerForPrediction(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.model = TimerModel(self.config) - lm_head_list = [] - self.output_token_len_map = {} - for i, output_token_len in enumerate(self.config.output_token_lens): - lm_head_list.append( - nn.Linear(self.config.hidden_size, 
output_token_len, bias=False) - ) - self.output_token_len_map[output_token_len] = i - self.lm_heads = nn.ModuleList(lm_head_list) - self.loss_function = torch.nn.MSELoss(reduction="none") - - def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - max_output_length: Optional[int] = None, - revin: Optional[bool] = True, - ): - if revin: - means, stdev = input_ids.mean(dim=-1, keepdim=True), input_ids.std( - dim=-1, keepdim=True - ) - input_ids = (input_ids - means) / stdev - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - ) - hidden_states = outputs.outputs - - if max_output_length is None: - output_token_len = self.config.output_token_lens[0] - max_output_length = output_token_len - else: - output_token_len = self.config.output_token_lens[0] - for h in self.config.output_token_lens[1:]: - if h > max_output_length: - break - else: - output_token_len = h - - lm_head = self.lm_heads[self.output_token_len_map[output_token_len]] - predictions = lm_head(hidden_states)[:, -1, :] - - if output_token_len > max_output_length: - predictions = predictions[:, :max_output_length] - if revin: - predictions = predictions * stdev + means - - return Output(predictions, outputs.past_key_values) - - -class Model(nn.Module): - """ - Timer-XL: Long-Context Transformers for Unified Time Series Forecasting - - Paper: https://arxiv.org/abs/2410.04803 - - GitHub: https://github.com/thuml/Timer-XL - - Citation: @article{liu2024timer, - title={Timer-XL: Long-Context Transformers for Unified Time Series Forecasting}, - author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng}, - journal={arXiv preprint arXiv:2410.04803}, - year={2024} - } - """ - - def __init__(self, config: TimerxlConfig): - super().__init__() - self.config = config # can't be scripted by torch - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = TimerForPrediction(config).to(self.device) - - if config.ckpt_path is not None and config.ckpt_path != "": - if config.ckpt_path.endswith(".pt") or config.ckpt_path.endswith(".pth"): - state_dict = torch.load(config.ckpt_path) - elif config.ckpt_path.endswith(".safetensors"): - if not os.path.exists(config.ckpt_path): - logger.info( - f"Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace..." - ) - repo_id = "thuml/timer-base-84m" - try: - config.ckpt_path = hf_hub_download( - repo_id=repo_id, - filename=os.path.basename(config.ckpt_path), - local_dir=os.path.dirname(config.ckpt_path), - ) - logger.info(f"Got checkpoint to {config.ckpt_path}") - except Exception as e: - logger.error( - f"Failed to download checkpoint to {config.ckpt_path} due to {e}" - ) - raise e - state_dict = load_safetensors(config.ckpt_path) - else: - raise ValueError("unsupported model weight type") - # If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.) - if not any(k.startswith("model.model") for k in state_dict.keys()): - state_dict = {"model." 
+ k: v for k, v in state_dict.items()} - self.load_state_dict(state_dict, strict=True) - - def set_device(self, device): - self.model.to(device) - self.device = next(self.model.parameters()).device - - def inference(self, x, max_new_tokens: int = 96): - # x.shape: [L, C], type: DataFrame - # here we only except C=1 temporarily - # change [L, C=1] to [batchsize=1, L] - self.device = next(self.model.parameters()).device - - x = torch.tensor( - x, dtype=next(self.model.parameters()).dtype, device=self.device - ) - x = x.view(1, -1) - - preds = self.forward(x, max_new_tokens) - preds = preds.detach().cpu().numpy() - - return preds - - def forward(self, x, max_new_tokens: int = 96): - # self.config.is_encoder_decoder = False - self.eval() - self.device = next(self.model.parameters()).device - - if len(x.shape) == 2: - batch_size, cur_len = x.shape - if cur_len < self.config.input_token_len: - raise ValueError( - f"Input length must be at least {self.config.input_token_len}" - ) - elif cur_len % self.config.input_token_len != 0: - new_len = ( - cur_len // self.config.input_token_len - ) * self.config.input_token_len - x = x[:, -new_len:] - else: - raise ValueError("Input shape must be: [batch_size, seq_len]") - - use_cache = self.config.use_cache - all_input_ids = x - - attention_mask = self.prepare_attention_mask_for_generation(all_input_ids) - all_input_ids_length = all_input_ids.shape[-1] - max_length = max_new_tokens + all_input_ids_length - - all_input_ids = all_input_ids.to(self.device) - batch_size, cur_len = all_input_ids.shape - - unfinished_sequences = torch.ones( - batch_size, dtype=torch.long, device=all_input_ids.device - ) - cache_position = torch.arange(cur_len, device=all_input_ids.device) - true_seq_len = cur_len // self.config.input_token_len - attention_mask = attention_mask[:, -true_seq_len:] - - this_peer_finished = False - past_key_values = None - position_ids = None - while not this_peer_finished: - (input_ids, position_ids, past_key_values, attention_mask, revin) = ( - self.prepare_inputs_for_generation( - all_input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - # position_ids=position_ids # Wrong?! - position_ids=None, # True?! 
based on huggingface code - ) - ) - - input_length = all_input_ids.shape[1] - - # forward pass to get next token - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - max_output_length=max_length - input_length, - revin=revin, - ) - - next_tokens = outputs.outputs - - # update generated ids, model inputs, and length for next step - horizon_length = next_tokens.shape[1] // self.config.input_token_len - - all_input_ids = torch.cat([all_input_ids, next_tokens], dim=-1) - (past_key_values, attention_mask, cache_position) = ( - self._update_model_kwargs_for_generation( - outputs, - attention_mask=attention_mask, - horizon_length=horizon_length, - cache_position=cache_position, - ) - ) - - unfinished_sequences = unfinished_sequences & ( - all_input_ids.shape[1] < max_length - ) - this_peer_finished = unfinished_sequences.max() == 0 - - if all_input_ids.shape[1] > max_length: - all_input_ids = all_input_ids[:, :max_length] - - return all_input_ids[:, -(max_length - cur_len) :] - - def prepare_attention_mask_for_generation( - self, - inputs: torch.Tensor, - ) -> torch.LongTensor: - return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - revin=True, - position_ids=None, - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - if isinstance(past_key_values, DynamicCache): - past_length = past_key_values.seen_tokens - else: - past_length = cache_length - - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > ( - input_ids.shape[1] // self.config.input_token_len - ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, past_length * self.config.input_token_len :] - # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + (input_ids.shape[1] // self.config.input_token_len) - > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[ - :, -(input_ids.shape[1] // self.config.input_token_len) : - ] - - return (input_ids, position_ids, past_key_values, attention_mask, revin) - - def _update_model_kwargs_for_generation( - self, - outputs, - attention_mask=None, - cache_position=None, - horizon_length: int = 1, - ) -> Dict[str, Any]: - # update past_key_values - past_key_values = outputs.past_key_values - - # update attention mask - if attention_mask is not None: - attention_mask = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], horizon_length)), - ], - dim=-1, - ) - - if cache_position is not None: - cache_position = cache_position[-1:] + horizon_length - - return (past_key_values, attention_mask, cache_position) diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index 175961e73134..eb8becd0f177 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -57,7 +57,9 @@ def infer(self, full_data, predict_length=96, **_): data = full_data[1][0] if data.dtype.byteorder not in ("=", "|"): data = data.byteswap().newbyteorder() - output = self.model.inference(data, int(predict_length)) + seqs = torch.tensor(data).unsqueeze(0).float() + # TODO: unify model inference input + output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) df = pd.DataFrame(output[0]) return convert_to_binary(df) diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index 6298fb6a1db3..8bd3bfc48002 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -41,8 +41,8 @@ from ainode.core.log import Logger from ainode.core.model.sundial import modeling_sundial from ainode.core.model.sundial.configuration_sundial import SundialConfig -from ainode.TimerXL.models import timer_xl -from ainode.TimerXL.models.configuration_timer import TimerxlConfig +from ainode.core.model.timerxl import modeling_timer +from ainode.core.model.timerxl.configuration_timer import TimerConfig logger = Logger() @@ -113,7 +113,7 @@ def fetch_built_in_model(model_id, inference_attributes): elif model_id == BuiltInModelType.STRAY.value: model = STRAYModel(attributes) elif model_id == BuiltInModelType.TIMER_XL.value: - model = timer_xl.Model(TimerxlConfig.from_dict(attributes)) + model = modeling_timer.TimerForPrediction(TimerConfig.from_dict(attributes)) elif model_id == BuiltInModelType.SUNDIAL.value: model = modeling_sundial.SundialForPrediction( SundialConfig.from_dict(attributes) diff --git a/iotdb-core/ainode/ainode/TimerXL/__init__.py b/iotdb-core/ainode/ainode/core/model/timerxl/__init__.py similarity index 100% rename from iotdb-core/ainode/ainode/TimerXL/__init__.py rename to iotdb-core/ainode/ainode/core/model/timerxl/__init__.py diff --git 
a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py similarity index 68% rename from iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py rename to iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py index ac5034aa85ec..34f9de91b633 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py +++ b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py @@ -15,27 +15,30 @@ # specific language governing permissions and limitations # under the License. # + from typing import List +from transformers import PretrainedConfig + -class TimerxlConfig: - model_type = "timerxl" +class TimerConfig(PretrainedConfig): + model_type = "timer" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - input_token_len: int = 96, # how many points as a token, don't change - hidden_size: int = 1024, # model hidden size - intermediate_size: int = 2048, # ffn middle size - output_token_lens: List[int] = [96], # how many points as a token, don't change + input_token_len: int = 1, + hidden_size: int = 1024, + intermediate_size: int = 2048, + output_token_lens: List[int] = [1, 8, 32, 64], num_hidden_layers: int = 8, num_attention_heads: int = 8, - hidden_act: str = "silu", # activation function - use_cache: bool = True, # kv cache - rope_theta: int = 10000, # ROBE parameter + hidden_act: str = "silu", + use_cache: bool = True, + rope_theta: int = 10000, attention_dropout: float = 0.0, - initializer_range: float = 0.02, # be of no use, because we already have weights + initializer_range: float = 0.02, max_position_embeddings: int = 10000, - ckpt_path: str = None, # weight path **kwargs, ): self.input_token_len = input_token_len @@ -50,12 +53,7 @@ def __init__( self.attention_dropout = attention_dropout self.initializer_range = initializer_range self.max_position_embeddings = max_position_embeddings - self.ckpt_path = ckpt_path super().__init__( **kwargs, ) - - @classmethod - def from_dict(cls, config_dict: dict) -> "TimerxlConfig": - return cls(**config_dict) diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py new file mode 100644 index 000000000000..42b3a82b9724 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py @@ -0,0 +1,680 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file as load_safetensors +from torch import nn +from transformers import Cache, DynamicCache, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) + +from ainode.core.log import Logger +from ainode.core.model.timerxl.configuration_timer import TimerConfig +from ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin + +logger = Logger() + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class TimerPatchEmbedding(nn.Module): + def __init__(self, config: TimerConfig): + super().__init__() + self.input_token_len = config.input_token_len + self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False) + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.unfold( + dimension=-1, size=self.input_token_len, step=self.input_token_len + ) + return self.emb(hidden_state) + + +class TimerPointEmbedding(nn.Module): + def __init__(self, config: TimerConfig): + super().__init__() + self.emb_layer = nn.Linear( + config.input_token_len, config.hidden_size, bias=False + ) + self.gate_layer = nn.Linear( + config.input_token_len, config.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x) + return emb + + +class TimeMoeRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64 + ).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class TimerAttention(nn.Module): + def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.attention_dropout = config.attention_dropout + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.rotary_emb = TimeMoeRotaryEmbedding( + self.head_dim, max_position_embeddings=config.max_position_embeddings + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx + ) + + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + dropout_p=self.attention_dropout, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class TimerMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): + 
super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class TimerDecoderLayer(nn.Module): + def __init__(self, config: TimerConfig, layer_idx: int): + super().__init__() + self.self_attn = TimerAttention(config, layer_idx) + + self.ffn_layer = TimerMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.norm1 = torch.nn.LayerNorm(config.hidden_size) + self.norm2 = torch.nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + Optional[torch.FloatTensor], + Optional[torch.FloatTensor], + ]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + hidden_states = self.norm1(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_layer(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.norm2(hidden_states) + + if not output_attentions: + self_attn_weights = None + + if not use_cache: + present_key_value = None + return hidden_states, self_attn_weights, present_key_value + + +class TimerPreTrainedModel(PreTrainedModel): + config_class = TimerConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TimeMoeDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class TimerModel(TimerPreTrainedModel): + def __init__(self, config: TimerConfig): + super().__init__(config) + self.embed_layer = TimerPatchEmbedding(config) + self.layers = nn.ModuleList( + [ + TimerDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = torch.nn.LayerNorm(config.hidden_size) + self.gradient_checkpointing = False + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: 
Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + # input_ids is the input of time series, its shape is [batch_size, seq_len] + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_layer(input_ids) + seq_length = inputs_embeds.shape[1] + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=None, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache = layer_outputs[2] + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + 
return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin): + def __init__(self, config: TimerConfig): + super().__init__(config) + self.config = config + self.model = TimerModel(self.config) + lm_head_list = [] + self.output_token_len_map = {} + for i, output_token_len in enumerate(self.config.output_token_lens): + lm_head_list.append( + nn.Linear(self.config.hidden_size, output_token_len, bias=False) + ) + self.output_token_len_map[output_token_len] = i + self.lm_heads = nn.ModuleList(lm_head_list) + self.loss_function = torch.nn.MSELoss(reduction="none") + # TODO: Unify data loader + if not os.path.exists(config.ckpt_path): + os.mkdir(config.ckpt_path) + weights_path = os.path.join(config.ckpt_path, "model.safetensors") + if not os.path.exists(weights_path): + logger.info( + f"Weight not found at {weights_path}, downloading from HuggingFace..." + ) + repo_id = "thuml/sundial-base-128m" + try: + hf_hub_download( + repo_id=repo_id, + filename="model.safetensors", + local_dir=config.ckpt_path, + ) + logger.info(f"Got weight to {weights_path}") + except Exception as e: + logger.error(f"Failed to download weight to {weights_path} due to {e}") + raise e + state_dict = load_safetensors(weights_path) + self.load_state_dict(state_dict, strict=True) + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + loss_masks: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + max_output_length: Optional[int] = None, + revin: Optional[bool] = False, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if revin: + mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std( + dim=-1, keepdim=True + ) + input_ids = (input_ids - mean) / std + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + predictions = None + + loss = None + if labels is not None: + ar_loss = 0.0 + for lm_head, output_token_len in zip( + self.lm_heads, self.config.output_token_lens + ): + one_predictions = lm_head(hidden_states) + one_loss = self.calc_ar_loss( + one_predictions, labels, loss_masks, output_token_len + ) + ar_loss += one_loss + if 
predictions is None: + predictions = one_predictions + loss = ar_loss / len(self.config.output_token_lens) + else: + if max_output_length is None: + output_token_len = self.config.output_token_lens[0] + max_output_length = output_token_len + else: + output_token_len = self.config.output_token_lens[0] + for h in self.config.output_token_lens[1:]: + if h > max_output_length: + break + else: + output_token_len = h + lm_head = self.lm_heads[self.output_token_len_map[output_token_len]] + predictions = lm_head(hidden_states)[:, -1, :] + if output_token_len > max_output_length: + predictions = predictions[:, :max_output_length] + if revin: + predictions = predictions * std + mean + if not return_dict: + output = (predictions,) + outputs[1:] + return (loss) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=predictions, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len): + seq_len = predictions.shape[1] * self.config.input_token_len + labels = labels[:, : seq_len - self.config.input_token_len + output_token_len] + shift_labels = labels.unfold( + dimension=-1, size=output_token_len, step=self.config.input_token_len + ) + + # Calculate loss with mask + losses = self.loss_function(predictions, shift_labels).mean(dim=-1) + if loss_masks is not None: + losses = losses * loss_masks + loss = losses.sum() / loss_masks.sum() + else: + loss = torch.mean(losses) + + return loss + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + revin=True, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + if isinstance(past_key_values, DynamicCache): + past_length = past_key_values.seen_tokens + else: + past_length = cache_length + + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > ( + input_ids.shape[1] // self.config.input_token_len + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < (input_ids.shape[1] // self.config.input_token_len): + input_ids = input_ids[:, past_length * self.config.input_token_len :] + # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + (input_ids.shape[1] // self.config.input_token_len) + > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "revin": revin, + } + ) + return model_inputs diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py new file mode 100644 index 000000000000..165d3c55e448 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList +from transformers.generation import EosTokenCriteria, validate_stopping_criteria +from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + GenerateNonBeamOutput, + GenerateOutput, + GenerationConfig, +) +from transformers.utils import ModelOutput + + +class TSGenerationMixin(GenerationMixin): + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + if len(inputs.shape) == 2: + batch_size, cur_len = inputs.shape + if cur_len < self.config.input_token_len: + raise ValueError( + f"Input length must be at least {self.config.input_token_len}" + ) + elif cur_len % self.config.input_token_len != 0: + new_len = ( + cur_len // self.config.input_token_len + ) * self.config.input_token_len + inputs = inputs[:, -new_len:] + else: + raise ValueError("Input shape must be: [batch_size, seq_len]") + return super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + **kwargs, + ) + + def _greedy_search( + self, + input_ids: torch.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.Tensor]: + input_ids = input_ids.to(self.device) + batch_size, cur_len = input_ids.shape + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + if eos_token_id is not None: + 
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None + ) + + # keep track of which sequences are already finished + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + this_peer_finished = False + unfinished_sequences = torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + true_seq_len = cur_len // self.config.input_token_len + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, -true_seq_len: + ] + max_length = stopping_criteria.max_length + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + input_length = input_ids.shape[1] + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + max_output_length=max_length - input_length, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_tokens_scores,) + if output_logits: + 
raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # argmax + # next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_tokens = next_tokens_scores + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + + # update generated ids, model inputs, and length for next step + horizon_length = next_tokens.shape[1] // self.config.input_token_len + + input_ids = torch.cat([input_ids, next_tokens], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + horizon_length=horizon_length, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) + this_peer_finished = unfinished_sequences.max() == 0 + + if input_ids.shape[1] > max_length: + input_ids = input_ids[:, :max_length] + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids[:, -(max_length - cur_len) :] + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + horizon_length: int = 1, + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [ + attention_mask, + attention_mask.new_ones( + (attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + 
model_kwargs["decoder_attention_mask"] = torch.cat( + [ + decoder_attention_mask, + decoder_attention_mask.new_ones( + (decoder_attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + + if ( + "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = ( + model_kwargs["cache_position"][-1:] + horizon_length + ) + + return model_kwargs