Skip to content

Commit 7db2be8

Browse files
authored
Merge pull request #32 from AdityaLab/Shiduo
gifteval experiments
2 parents 97f7ee5 + c0fa3a6 commit 7db2be8

File tree

5 files changed

+205
-12
lines changed

5 files changed

+205
-12
lines changed

leaderboard.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def calc_pred_and_context_len(freq):
146146

147147
if __name__ == "__main__":
148148

149-
for model_name in MODEL_NAMES[3:]:
149+
for model_name in ["ttm"]:
150150
print(f"Evaluating model: {model_name}")
151151
# create csv file for leaderboard if not already created
152152
csv_path = f"leaderboard/{model_name}.csv"
@@ -197,8 +197,10 @@ def calc_pred_and_context_len(freq):
197197
dataset_path = f"data/gifteval/{fname}/{freq}/data.csv"
198198

199199
if model_name == "timesfm":
200-
model = TimesfmModel(**args)
200+
201201
dataset = TimesfmDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=args["config"]["context_len"], horizon_len=args["config"]["horizon_len"], boundaries=(-1, -1, -1), batchsize=64)
202+
args["config"]["horizon_len"] = dataset.horizon_len
203+
model = TimesfmModel(**args)
202204
start = time.time()
203205
metrics = model.evaluate(dataset)
204206
print("Metrics: ", metrics)
@@ -207,10 +209,12 @@ def calc_pred_and_context_len(freq):
207209
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
208210

209211
elif model_name == "moment":
210-
model = MomentModel(**args)
212+
211213
args["config"]["task_name"] = "forecasting"
212214
train_dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='train', horizon_len=args["config"]["forecast_horizon"], normalize=False)
213215
dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='test', horizon_len=args["config"]["forecast_horizon"], normalize=False, boundaries=[-1, -1, -1])
216+
args["config"]["forecast_horizon"] = dataset.forecast_horizon
217+
model = MomentModel(**args)
214218
finetuned_model = model.finetune(train_dataset, task_name="forecasting")
215219
start = time.time()
216220
metrics = model.evaluate(dataset, task_name="forecasting")
@@ -220,11 +224,13 @@ def calc_pred_and_context_len(freq):
220224
print(metrics)
221225

222226
elif model_name == "chronos":
223-
model = ChronosModel(**args)
227+
224228
dataset_config = load_args("config/chronos_dataset.json")
225229
dataset_config["context_length"] = context_len
226230
dataset_config["prediction_length"] = pred_len
227231
dataset = ChronosDataset(datetime_col='timestamp', path=dataset_path, mode='test', config=dataset_config, batch_size=4, boundaries=[-1, -1, -1])
232+
args["config"]["context_length"] = dataset.horizon_len
233+
model = ChronosModel(**args)
228234
start = time.time()
229235
metrics = model.evaluate(dataset, horizon_len=dataset_config["prediction_length"], quantile_levels=[0.1, 0.5, 0.9])
230236
end = time.time()
@@ -242,8 +248,10 @@ def calc_pred_and_context_len(freq):
242248
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
243249

244250
elif model_name == "ttm":
245-
model = TinyTimeMixerModel(**args)
251+
246252
dataset = TinyTimeMixerDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=context_len, horizon_len=pred_len, boundaries=[-1, -1, -1])
253+
args["config"]["horizon_len"] = dataset.horizon_len
254+
model = TinyTimeMixerModel(**args)
247255
start = time.time()
248256
metrics = model.evaluate(dataset)
249257
end = time.time()
@@ -271,12 +279,13 @@ def calc_pred_and_context_len(freq):
271279

272280

273281
df = pd.read_csv(csv_path)
274-
if fname in df["dataset"].values:
275-
df.loc[df["dataset"] == fname, "size_in_MB"] = round(fs,2)
276-
df.loc[df["dataset"] == fname, "eval_time"] = str(round(eval_time,2)) + unit
277-
df.loc[df["dataset"] == fname, list(metrics.keys())] = list(metrics.values())
282+
row_name = fname + ' (' + freq + ')'
283+
if row_name in df["dataset"].values:
284+
df.loc[df["dataset"] == row_name, "size_in_MB"] = round(fs,2)
285+
df.loc[df["dataset"] == row_name, "eval_time"] = str(round(eval_time,2)) + unit
286+
df.loc[df["dataset"] == row_name, list(metrics.keys())] = list(metrics.values())
278287
else:
279-
new_row = pd.DataFrame([{**{"dataset": fname, "size_in_MB":round(fs,2), "eval_time":str(round(eval_time,2)) + unit}, **metrics}])
288+
new_row = pd.DataFrame([{**{"dataset": row_name, "size_in_MB":round(fs,2), "eval_time":str(round(eval_time,2)) + unit}, **metrics}])
280289
df = pd.concat([df, new_row], ignore_index=True)
281290

282291
df.to_csv(csv_path, index=False)

leaderboard/moment.csv

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
dataset,size_in_MB,eval_time,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
2+
us_births,0.13,16.53s,0.0166749041527509,0.0132494084537029,0.6622852683067322,0.0336879082024097,0.1291313469409942,0.0200039861313061,0.2919767796993255,2393.22509765625,1.7653433739843682,,
3+
ett1,0.05,0.49s,0.0845161378383636,0.0771553590893745,0.818662166595459,0.1507003307342529,0.2907165884971618,0.0408028170847522,0.3720319271087646,0.0229524467140436,37.06777924404701,,
4+
ett2,0.05,0.48s,0.1525356918573379,0.0985602214932441,0.8158197999000549,0.159540593624115,0.3905581831932068,0.0520415722981005,0.3430469930171966,0.0324768535792827,-0.7515478683683955,,
5+
saugeenday,0.38,57.42s,0.0181519500911235,0.0106288585811853,0.8184567093849182,0.0984068512916565,0.1347291767597198,0.0080451322495082,0.2162542045116424,856.370361328125,65.20783161905682,,
6+
solar,0.35,0.69s,1.0479254722595217,0.6552554965019226,1.2421728372573853,1.644118309020996,1.0236823558807373,0.2602073408719523,1.2416865825653076,0.1769601106643676,-1.5701247442115474,,
7+
jena_weather,0.08,0.26s,0.4467118084430694,0.290574699640274,,-0.3770225942134857,0.6683650016784668,0.0439291477903074,0.626776397228241,0.0621072389185428,-6.0627986557802815,,
8+
hierarchical_sales,0.9,6.09s,0.9766348600387572,0.6037952899932861,1.2904167175292969,-0.0605682358145713,0.9882484078407288,0.0296427630370565,1.2487094402313232,0.1234115213155746,-14.845645767332355,,
9+
bizitobs_l2c,0.18,5.12s,0.1141704395413398,0.094519555568695,0.5653119683265686,0.0437845885753631,0.337891161441803,0.0411253008501759,0.2735987901687622,0.0017722472548484,-43.61411877196251,,
10+
M_DENSE,0.21,0.48s,0.65444016456604,0.4488910138607025,0.7423524856567383,0.2140763998031616,0.8089747428894043,0.0746180797828536,0.8750887513160706,0.0955005586147308,-12.196748869327852,,
11+
covid_deaths,0.27,1.13s,2532.86962890625,15.704353332519531,0.5370470285415649,0.5348979234695435,50.32762145996094,0.0688400358056701,0.8886440992355347,0.0817901268601417,0.8951871516715,,
12+
bizitobs_application,0.33,20.28s,0.031650137156248,0.0232451893389225,0.8000913262367249,-9.096251487731934,0.1779048591852188,0.0292253176385202,0.1688371747732162,2183.644775390625,-7.934246776719986,,
13+
hospital,0.35,2.76s,1.871889352798462,1.0410010814666748,0.5708366632461548,38.29072570800781,1.3681700229644775,0.0623355435952419,1.2547467947006226,0.1276404559612274,1.871985084498025,,
14+
car_parts_with_missing,0.58,9.95s,3.5560872554779053,0.8787520527839661,0.7763674855232239,2191.994140625,1.885759115219116,0.0130401857953503,1.6353957653045654,0.1159390807151794,10.414244728513925,,
15+
electricity,0.66,1.37s,231504363520.0,33604.7578125,0.5374543070793152,36700.98828125,481149.0,0.0324357633701603,1.2939711809158323,0.7841228246688843,0.8005288085222292,,
16+
kdd_cup_2018_with_missing,1.08,1.12s,0.8444492816925049,0.5074900984764099,0.636409342288971,-0.1625679284334182,0.9189392328262328,0.0264241123643264,0.993318736553192,0.1114950478076934,-4.843327060262989,,
17+
LOOP_SEATTLE,1.13,1.37s,1.1636507511138916,0.7680737972259521,0.962784230709076,-0.2038627117872238,1.0787264108657837,0.0758186484268039,1.0627912282943726,0.1333916038274765,-3.549600403287944,,
18+
us_births (M),0.0,0.26s,0.0490848012268543,0.025092938914895,,0.198994442820549,0.2215508967638015,0.0619808429064959,0.332131415605545,2407.629638671875,1.0232207619288771,,
19+
ett1 (W),0.01,0.26s,0.0661440715193748,0.0679000094532966,,0.1340852975845337,0.2571848928928375,0.0600888141994269,0.3639847040176391,0.0118538662791252,13.56712316983709,,
20+
ett2 (W),0.01,0.5s,0.1784712672233581,0.1021255776286125,,0.3991040885448456,0.4224585890769958,0.0620690916861361,0.4547058641910553,0.0119766816496849,-0.6570217851294525,,
21+
saugeenday (M),0.02,0.52s,0.0146793872117996,0.0103598264977335,1.1410484313964844,0.114974558353424,0.121158517897129,0.0245601202086958,0.2466829866170883,1111.4794921875,-15.320425227975118,,
22+
us_births (W),0.02,1.36s,0.0194058045744895,0.0143667459487915,1.4129470586776731,0.1352113783359527,0.1393047124147415,0.0343529098623171,0.2357538044452667,2704.560546875,0.7064739408163458,,
23+
ett1 (D),0.05,0.49s,0.0856200754642486,0.0776782855391502,0.8242107033729553,0.1468994617462158,0.2926090955734253,0.0410684352954762,0.3727368116378784,0.0231287106871604,37.319009002160186,,
24+
ett2 (D),0.05,0.57s,0.1532271057367324,0.0985804125666618,0.8159868717193604,0.1755074113607406,0.3914423286914825,0.0521593839939005,0.3719439208507538,0.0325935669243335,-0.751701830666341,,
25+
solar (W),0.06,0.69s,0.3245261907577514,0.3878287076950073,0.8207972049713135,-1.0662380456924438,0.5696719884872437,0.1840309984690182,0.8310236930847168,0.1476865410804748,-0.8024131291804086,,
26+
saugeenday (W),0.07,11.41s,0.0171893555670976,0.0098273605108261,0.7497754096984863,0.1031473726034164,0.1311081796884536,0.0117120778330479,0.20797760784626,1062.8035888671875,31.996620891779777,,
27+
jena_weather (D),0.08,0.5s,0.4404400885105133,0.2876468002796173,,-0.4747923910617828,0.6636565923690796,0.0436196815437357,0.6286620497703552,0.0604314506053924,-6.001708463378722,,
28+
us_births (D),0.13,16.44s,0.0165644250810146,0.0131850000470876,0.6590656042098999,0.0071459645405411,0.1287028491497039,0.0199376067116119,0.2973404228687286,2406.93896484375,1.7567616358453075,,
29+
hierarchical_sales (W),0.15,0.49s,1.111017107963562,0.6712689399719238,0.6610993146896362,-0.2276024222373962,1.05404794216156,0.0589132884568094,1.2139064073562622,0.1194732338190078,-4.8134139784308285,,
30+
bizitobs_l2c (H),0.18,5.16s,0.1132991909980773,0.0941098034381866,0.5628612637519836,0.042415402829647,0.3365994393825531,0.040968083188504,0.2739208936691284,0.0018358565866947,-43.42504701872018,,
31+
M_DENSE (D),0.21,0.54s,0.6656246781349182,0.4525275528430938,0.7483664155006409,0.2164407223463058,0.8158582448959351,0.0752529991130594,0.8902981877326965,0.0965064167976379,-12.29555671210528,,
32+
covid_deaths (D),0.27,1.13s,2528.41943359375,15.66344928741455,0.5356481671333313,0.5312768220901489,50.283390045166016,0.0687795343933266,0.8796680569648743,0.0815204679965972,0.8928555194893874,,
33+
bizitobs_application (10s),0.33,22.71s,0.0317411497235298,0.0232481099665164,0.8001918196678162,-9.959293365478516,0.1781604588031768,0.0292673062613005,0.1536118537187576,2209.325439453125,-7.935243670302386,,
34+
solar (D),0.35,1.33s,1.047087788581848,0.6567782163619995,1.2450594902038574,1.7552998065948486,1.0232731103897097,0.2601033157509053,1.2456032037734983,0.1776239722967147,-1.5737734890806443,,
35+
hospital (ME),0.35,2.86s,1.859989047050476,1.0373529195785522,0.5688363313674927,37.09786605834961,1.363814115524292,0.0621370829846621,1.2558231353759766,0.1271890103816986,1.8654247602467016,,
36+
saugeenday (D),0.38,64.48s,0.0180799681693315,0.0105487462133169,0.812287449836731,0.0973807573318481,0.1344617754220962,0.0080291648163463,0.2135202288627624,865.2394409179688,64.71634386854565,,
37+
car_parts_with_missing (ME),0.58,9.97s,3.562715530395508,0.875319242477417,0.7733346819877625,2163.47216796875,1.887515664100647,0.0130523324813129,1.6376537084579468,0.1151701137423515,10.37356189138631,,
38+
electricity (W),0.66,3.86s,230903316480.0,33545.01953125,0.5364989042282104,40271.640625,480524.0,0.0323936301596448,1.3060535192489624,0.7850792407989502,0.7991057298207234,,
39+
hierarchical_sales (D),0.9,16.82s,0.9756141304969788,0.6047087907791138,1.29236900806427,-0.0754419490694999,0.9877318143844604,0.0296272676845832,1.255674123764038,0.1236139833927154,-14.86810620930554,,
40+
kdd_cup_2018_with_missing (D),1.08,1.13s,0.8476442694664001,0.5114850997924805,0.6414192318916321,-0.1579349040985107,0.9206759929656982,0.0264740529299653,1.0037513971328735,0.1125412508845329,-4.88145410557481,,
41+
LOOP_SEATTLE (D),1.13,1.38s,1.1801847219467163,0.7724948525428772,0.9683259129524232,-0.2286671549081802,1.0863630771636963,0.0763553940847966,1.061790943145752,0.134036049246788,-3.5700320073767613,,
42+
SZ_TAXI (H),1.14,2.58s,0.1794196963310241,0.2704257071018219,0.742323100566864,0.9259967803955078,0.4235796332359314,0.0629424358185336,0.4413276612758636,0.0920828655362129,-0.4234020415293309,,
43+
ett1 (H),1.22,71.91s,0.0931190550327301,0.0700619220733642,1.2231614589691162,0.2108829170465469,0.3051541447639465,0.0278852932860295,0.3243643641471863,0.0397731065750122,-18.09391609509558,,
44+
ett2 (H),1.26,54.03s,0.0527420304715633,0.0486048087477684,1.0875681638717651,0.1189476400613784,0.2296563237905502,0.0182699874468558,0.3114534318447113,0.0579129941761493,-0.9500558201964512,,
45+
jena_weather (H),1.65,20.61s,0.269149512052536,0.1946270763874054,1.0476793050765991,-0.025795079767704,0.5187962055206299,0.0046744943853909,0.4797450304031372,0.0455239117145538,4.279154276658326,,
46+
bizitobs_l2c (5T),1.68,93.12s,0.1061218082904815,0.0827776566147804,1.6546341180801392,0.0960157662630081,0.3257634341716766,0.0395778664079169,0.196520447731018,0.0183309577405452,60.77264861275297,,
47+
restaurant (D),1.77,8.25s,32.19078826904297,1.542858362197876,0.5222368836402893,21.732288360595703,5.67369270324707,0.021389853140395,1.0622296333312988,0.163014754652977,0.8957319824794684,,
48+
m4_hourly (h),2.43,22.74s,0.6365521550178528,0.6075160503387451,2.4075798988342285,-0.0194074865430593,0.7978422045707703,0.0582948840714332,0.965739369392395,0.1226370185613632,1.5180986175834796,,
49+
bizitobs_service (10s),3.07,38.71s,0.6393813490867615,0.4555206298828125,0.7736144065856934,-0.4311994910240173,0.7996132373809814,0.0578584589820048,0.981988787651062,0.0674135088920593,66.96168597556668,,
50+
M_DENSE (H),3.7,42.05s,0.5782449841499329,0.444500982761383,1.7446072101593018,0.245728924870491,0.7604241967201233,0.0899721087789721,0.9053943753242492,0.0985398814082145,-111.408677600323,,
51+
ett1 (15T),4.4,312.68s,0.0770852044224739,0.0642813295125961,1.3816813230514526,-0.0450062677264213,0.2776422202587127,0.0248858562863315,0.2945191562175751,0.0444143638014793,-29.72643510740077,,
52+
ett2 (15T),4.57,170.67s,0.0352975837886333,0.0388589762151241,1.1798412799835205,0.1034263670444488,0.1878765076398849,0.0121796422919794,0.2683697342872619,0.0611041486263275,-0.8056754890308132,,
53+
SZ_TAXI (15T),4.58,17.77s,0.3003464341163635,0.3433499932289123,1.583962321281433,-1.0764042139053345,0.5480387210845947,0.0426858127010438,0.6180946826934814,0.0975749045610427,-0.734930536398494,,
54+
electricity (D),4.63,12.82s,265048064.0,909.3511962890624,1.0505276918411257,9888.9248046875,16280.296875,0.0079866265077432,0.6811089515686035,1.2647544145584106,0.153261294375977,,
55+
solar (H),5.97,60.66s,0.4952342212200165,0.4485998749732971,5.05915117263794,-0.4632841348648071,0.7037287950515747,0.1873204244722185,0.9461215734481812,0.1192889809608459,-3.141420588442726,,
56+
bitbrains_rnd (H),6.1,7.56s,2541.150634765625,4.16558313369751,0.6189665794372559,5294.056640625,50.40982818603516,0.0065061475422719,0.7933628559112549,0.073084145784378,1.014913167418809,,
57+
m4_weekly (W-SUN),7.18,29.96s,2.014073371887207,0.2677460014820099,1.0587129592895508,-0.0184317361563444,1.4191805124282837,0.0067748010896629,0.1633300334215164,0.0710709914565086,0.1223636577958742,,
58+
jena_weather (10T),7.18,129.14s,0.2310131639242172,0.1709593832492828,0.8379159569740295,0.4183250367641449,0.4806382954120636,0.0018715012395996,0.4648206830024719,0.0465143471956253,4.276199613918661,,
59+
kdd_cup_2018_with_missing (H),14.28,128.7s,1.903819441795349,0.616597056388855,4.51470422744751,-0.9062976241111756,1.3797895908355713,0.0370714838689682,1.1574369668960571,0.1389963328838348,10.823472138906274,,
60+
bitbrains_fast_storage (H),15.64,18.96s,10086.7998046875,2.826951742172241,0.5318763256072998,564.5944213867188,100.4330596923828,0.0093783960005972,0.8978911638259888,0.9046183824539183,0.6542457009370054,,
61+
LOOP_SEATTLE (H),27.05,121.61s,0.7220565676689148,0.6239529252052307,4.835772037506104,0.7766930460929871,0.849739134311676,0.0553899985984129,1.426215887069702,0.1485606729984283,6.506469472198424,,
62+
solar (10T),33.4,387.49s,0.3633464574813843,0.3939985632896423,0.9862849712371826,0.1683084964752197,0.6027822494506836,0.1532161214761041,0.8634337186813354,0.106142945587635,-1.932767178107501,,
63+
m4_yearly (YE-DEC),51.4,85.37s,0.2249934822320938,0.3888959884643554,1.2640678882598877,1.3605670928955078,0.4743347764015198,0.0156129467417141,1.4391971826553345,0.2767557203769684,1.1485716617412427,,
64+
bitbrains_rnd (5T),63.69,320.12s,473.8028564453125,2.5002923011779785,14.552663803100586,36466.94140625,21.767013549804688,0.0124533049345982,0.7882186770439148,0.0320991352200508,0.4796458303331513,,
65+
electricity (H),110.58,513.65s,428889.8125,42.35646438598633,38.55291366577149,2073.8056640625,654.8967895507812,0.0074876989582218,0.8307960629463196,1.041404128074646,0.1693215336592214,,
66+
temperature_rain_with_missing (D),113.99,239.13s,10.15415382385254,0.8827998638153076,1.944754958152771,8.244161605834961,3.186558246612549,0.0034670636363697,1.4160765409469604,0.1196500360965728,2.4729034145101805,,
67+
bitbrains_fast_storage (5T),160.06,801.34s,1755.422607421875,2.4329330921173096,13.81318473815918,2202.016357421875,41.897762298583984,0.0038607244908998,0.8529143929481506,0.8787821531295776,0.5712061641206433,,
68+
m4_quarterly (QE-DEC),163.93,267.8s,0.0343928597867488,0.0017406134866178,0.0162255093455314,0.0002408254076726,0.1854531168937683,0.0008958026187081,0.0006573513965122,0.0275926832109689,0.0059233256104646,,
69+
m4_daily (D),316.28,26.11m,0.0361730828881263,0.0036603524349629,0.2650297284126282,0.0032658216077834,0.1901922225952148,0.0001379508480314,0.0050389510579407,0.2410273253917694,0.0056191225238017,,
70+
LOOP_SEATTLE (5T),324.08,25.98m,0.6140819191932678,0.4618953168392181,2.7076990604400635,0.4295462965965271,0.7836337685585022,0.0466134961183778,1.013070583343506,0.1044607460498809,1.694795358113825,,
71+
electricity (15T),442.39,34.73m,44896.16796875,12.45554256439209,103.10625457763672,513.2125244140625,211.8871612548828,0.0113598928018075,0.8586267232894897,1.054625153541565,0.2306723787020087,,
72+
m4_monthly (ME),1025.34,73.67m,0.00413757236674428,0.0005038601229898632,0.042170461267232895,0.000296300946502015,0.06432396173477173,0.0029465290793323916,0.0004989044973626733,0.010481921955943108,0.0043333082261389575,,

0 commit comments

Comments
 (0)