Skip to content

Commit 787d8a9

Browse files
committed
Add fix for auto models
1 parent 14af4d0 commit 787d8a9

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

neuralforecast/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,21 @@ def fit(
537537
raise ValueError(
538538
"All entries in the list of files must be of type string"
539539
)
540-
max_input_size = max(m.input_size for m in self.models)
540+
max_input_size = 0
541+
for m in self.models:
542+
if hasattr(m, 'config') and isinstance(m.config, dict):
543+
input_sizes = m.config['input_size']
544+
if hasattr(input_sizes, 'categories'):
545+
max_input_size = max(max_input_size, max(input_sizes.categories))
546+
elif isinstance(input_sizes, (list, tuple)):
547+
max_input_size = max(max_input_size, max(input_sizes))
548+
else:
549+
max_input_size = max(max_input_size, input_sizes)
550+
else:
551+
max_input_size = max(max_input_size, m.input_size)
541552
max_h = max(m.h for m in self.models)
553+
if max_input_size == 0:
554+
max_input_size = 2*max_h
542555
max_size_limit = 2 * (max_input_size + max_h + val_size)
543556
self.dataset = self._prepare_fit_for_local_files(
544557
files_list=df,

0 commit comments

Comments
 (0)