
Commit 6a549ae

Moirai expts

1 parent 7db2be8 commit 6a549ae

4 files changed: +61 -62 lines changed

comp.py

Lines changed: 2 additions & 7 deletions

@@ -1,15 +1,10 @@
 import gc
 import os
-import sys
 
 import numpy as np
 import pandas as pd
 import torch
 
-src_path = os.path.abspath(os.path.join("src"))
-if src_path not in sys.path:
-    sys.path.insert(0, src_path)
-
 from samay.dataset import MoiraiDataset
 from samay.model import MoiraiTSModel
 from samay.utils import load_args
@@ -27,7 +22,7 @@ def update_leaderboard(dataset_name, model_name, metrics, leaderboard_path):
     """
     if not os.path.exists(leaderboard_path):
         # Create the leaderboard with appropriate columns if it doesn't exist
-        columns = ["Dataset"] + [
+        columns: list[str] = ["Dataset"] + [
             f"{model}_{metric}"
             for model in ["TimesFM", "Chronos", "Moirai"]
             for metric in metrics.keys()
@@ -131,6 +126,6 @@ def update_leaderboard(dataset_name, model_name, metrics, leaderboard_path):
         eval_results, _, _, _ = moirai.evaluate(val_dataset, metrics=["MSE", "MASE"])
         metrics = {"MSE": eval_results["MSE"], "MASE": eval_results["MASE"]}
         update_leaderboard(dataset, model_name, metrics, leaderboard_path)
-        del chronos
+        del moirai
         torch.cuda.empty_cache()
         gc.collect()
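
Note on the "del chronos" -> "del moirai" fix above: the old name was a leftover from a copy-pasted Chronos block, so the Moirai model was never the object being deleted and its CUDA memory was not released between datasets (or a NameError was raised if no chronos existed). A minimal sketch of the cleanup pattern, assuming an illustrative make_model/evaluate interface that is not part of samay:

import gc

import torch


def evaluate_sequentially(model_factories, val_dataset):
    # Evaluate several large models one at a time on a single GPU.
    results = {}
    for name, make_model in model_factories.items():
        model = make_model()  # load this model's weights onto the GPU
        results[name] = model.evaluate(val_dataset)
        del model                 # drop the last live reference to the model
        torch.cuda.empty_cache()  # hand cached CUDA blocks back to the driver
        gc.collect()              # sweep any lingering reference cycles
    return results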

download_data.py

Lines changed: 36 additions & 36 deletions

@@ -1,46 +1,46 @@
-from datasets import load_dataset
 import os
 
+from datasets import Dataset, load_dataset
 
 if __name__ == "__main__":
     save_dir = "data/monash"
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     dataset_names = [
-            "weather",
-            "tourism_yearly",
-            "tourism_quarterly",
-            "tourism_monthly",
-            "cif_2016",
-            "london_smart_meters",
-            "australian_electricity_demand",
-            "wind_farms_minutely",
-            "bitcoin",
-            "pedestrian_counts",
-            "vehicle_trips",
-            "kdd_cup_2018",
-            "nn5_daily",
-            "nn5_weekly",
-            "kaggle_web_traffic",
-            "kaggle_web_traffic_weekly",
-            "solar_10_minutes",
-            "solar_weekly",
-            "car_parts",
-            "fred_md",
-            "traffic_hourly",
-            "traffic_weekly",
-            "hospital",
-            "covid_deaths",
-            "sunspot",
-            "saugeenday",
-            "us_births",
-            "solar_4_seconds",
-            "wind_4_seconds",
-            "rideshare",
-            "oikolab_weather",
-            "temperature_rain"
-        ]
+        "weather",
+        "tourism_yearly",
+        "tourism_quarterly",
+        "tourism_monthly",
+        "cif_2016",
+        "london_smart_meters",
+        "australian_electricity_demand",
+        "wind_farms_minutely",
+        "bitcoin",
+        "pedestrian_counts",
+        "vehicle_trips",
+        "kdd_cup_2018",
+        "nn5_daily",
+        "nn5_weekly",
+        "kaggle_web_traffic",
+        "kaggle_web_traffic_weekly",
+        "solar_10_minutes",
+        "solar_weekly",
+        "car_parts",
+        "fred_md",
+        "traffic_hourly",
+        "traffic_weekly",
+        "hospital",
+        "covid_deaths",
+        "sunspot",
+        "saugeenday",
+        "us_births",
+        "solar_4_seconds",
+        "wind_4_seconds",
+        "rideshare",
+        "oikolab_weather",
+        "temperature_rain",
+    ]
     for dataset_name in dataset_names:
-        dataset = load_dataset("monash_tsf", dataset_name)
+        dataset: Dataset = load_dataset("monash_tsf", dataset_name)  # type: ignore
         dataset.save_to_disk(f"{save_dir}/{dataset_name}")
-    print(f"Downloaded {dataset_name} dataset")
+        print(f"Downloaded {dataset_name} dataset")

src/samay/models/moment/momentfm/utils/masking.py

Lines changed: 7 additions & 3 deletions

@@ -42,7 +42,7 @@ def convert_patch_to_seq_view(
         """
         return mask.repeat_interleave(patch_len, dim=-1)
 
-    def generate_mask(self, x: torch.Tensor, input_mask: Optional[torch.Tensor] = None):
+    def generate_mask(self, x: torch.Tensor, input_mask: torch.Tensor):
         """
         Input:
         x : torch.Tensor of shape
@@ -57,8 +57,12 @@ def generate_mask(self, x: torch.Tensor, input_mask: Optional[torch.Tensor] = None):
             return self._mask_patch_view(x, input_mask=input_mask)
         elif x.ndim == 3:
             return self._mask_seq_view(x, input_mask=input_mask)
+        else:
+            raise ValueError(
+                f"Invalid input shape: {x.shape}. Expected 3D or 4D tensor."
+            )
 
-    def _mask_patch_view(self, x, input_mask=None):
+    def _mask_patch_view(self, x, input_mask: torch.Tensor):
         """
         Input:
         x : torch.Tensor of shape
@@ -101,7 +105,7 @@ def _mask_patch_view(self, x, input_mask=None):
 
         return mask.long()
 
-    def _mask_seq_view(self, x, input_mask=None):
+    def _mask_seq_view(self, x, input_mask: torch.Tensor):
         """
         Input:
         x : torch.Tensor of shape
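
Before this change, an input that was neither 3-D nor 4-D fell off the end of the if/elif chain and generate_mask silently returned None; the new else branch turns that into an immediate ValueError. A toy version of the hardened dispatch, assuming the shape convention suggested by the method names ([batch, channels, n_patches, patch_len] for the patched view, [batch, channels, seq_len] for the sequence view); the function name and mask_ratio are illustrative:

import torch


def dispatch_mask(x: torch.Tensor, mask_ratio: float = 0.3) -> torch.Tensor:
    if x.ndim == 4:
        n = x.shape[2]   # patched view: mask whole patches
    elif x.ndim == 3:
        n = x.shape[-1]  # sequence view: mask individual timesteps
    else:
        raise ValueError(f"Invalid input shape: {x.shape}. Expected 3D or 4D tensor.")
    # 1 = keep, 0 = masked, matching the long() masks returned above
    return (torch.rand(x.shape[0], n) > mask_ratio).long()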

transform_ILI.py

Lines changed: 16 additions & 16 deletions

@@ -1,39 +1,39 @@
 import pandas as pd
-from datetime import datetime
-
 
 if __name__ == "__main__":
     df = pd.read_csv("data/Flu_USA/ILINet.csv")
 
-    df = df[df['REGION TYPE'] == 'National']
-    df = df[df['WEEK'] != 53]
+    df = df[df["REGION TYPE"] == "National"]
+    df = df[df["WEEK"] != 53]
 
-    df['date'] = pd.to_datetime(df['YEAR'].astype(str) + '-W' + df['WEEK'].astype(str) + '-1', format='%Y-W%U-%w')
+    df["date"] = pd.to_datetime(
+        df["YEAR"].astype(str) + "-W" + df["WEEK"].astype(str) + "-1",
+        format="%Y-W%U-%w",
+    )
 
     result = []
     for i in range(len(df) - 1):
-      result.append(df.iloc[i])
+        result.append(df.iloc[i])
 
-        if (df.iloc[i + 1]['date'] - df.iloc[i]['date']).days == 14:
+        if (df.iloc[i + 1]["date"] - df.iloc[i]["date"]).days == 14:
             new_row = df.iloc[i].copy()
-            new_row['date'] = df.iloc[i]['date'] + pd.Timedelta(days=7)
+            new_row["date"] = df.iloc[i]["date"] + pd.Timedelta(days=7)
             result.append(new_row)
-
+
     result.append(df.iloc[-1])
     df = pd.DataFrame(result)
 
-    df = df.drop(columns=['YEAR', 'WEEK'])
-    gaps = df['date'].diff().dropna().unique()
+    df = df.drop(columns=["YEAR", "WEEK"])
+    gaps = df["date"].diff().dropna().unique()
     print("Unique time intervals:", gaps)
-    df['time_diff'] = df['date'].diff()
+    df["time_diff"] = df["date"].diff()
 
-    rows_with_14_days = df[df['time_diff'] == pd.Timedelta(days=14)]
+    rows_with_14_days = df[df["time_diff"] == pd.Timedelta(days=14)]
     print(rows_with_14_days)
-    df = df.drop(columns=['time_diff'])
-    infered_freq = pd.infer_freq(df['date'])
+    df = df.drop(columns=["time_diff"])
+    infered_freq = pd.infer_freq(df["date"])
     print(f"Infered frequency: {infered_freq}")
 
     df.to_csv("data/Flu_USA/Flu_USA.csv", index=False)
 
     print("Data saved to output.csv")
-
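
The heart of this script is the gap-filling loop: dropping the WEEK == 53 rows leaves 14-day holes in an otherwise weekly series, and wherever two consecutive rows are 14 days apart the loop inserts a copy of the earlier row dated 7 days later, restoring the strict weekly cadence that pd.infer_freq can recognize. A self-contained run of that same pass on toy data:

import pandas as pd

df = pd.DataFrame({
    "date": pd.to_datetime(["2020-01-06", "2020-01-13", "2020-01-27"]),  # 14-day gap at the end
    "value": [1.0, 2.0, 3.0],
})

result = []
for i in range(len(df) - 1):
    result.append(df.iloc[i])
    if (df.iloc[i + 1]["date"] - df.iloc[i]["date"]).days == 14:
        new_row = df.iloc[i].copy()                             # duplicate the earlier week
        new_row["date"] = df.iloc[i]["date"] + pd.Timedelta(days=7)
        result.append(new_row)
result.append(df.iloc[-1])

weekly = pd.DataFrame(result)
print(weekly["date"].tolist())        # 2020-01-06, -13, -20, -27: strictly weekly
print(pd.infer_freq(weekly["date"]))  # "W-MON" for this Monday-anchored example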
