File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -537,8 +537,21 @@ def fit(
537537 raise ValueError (
538538 "All entries in the list of files must be of type string"
539539 )
540- max_input_size = max (m .input_size for m in self .models )
540+ max_input_size = 0
541+ for m in self .models :
542+ if hasattr (m , 'config' ) and isinstance (m .config , dict ):
543+ input_sizes = m .config ['input_size' ]
544+ if hasattr (input_sizes , 'categories' ):
545+ max_input_size = max (max_input_size , max (input_sizes .categories ))
546+ elif isinstance (input_sizes , (list , tuple )):
547+ max_input_size = max (max_input_size , max (input_sizes ))
548+ else :
549+ max_input_size = max (max_input_size , input_sizes )
550+ else :
551+ max_input_size = max (max_input_size , m .input_size )
541552 max_h = max (m .h for m in self .models )
553+ if max_input_size == 0 :
554+ max_input_size = 2 * max_h
542555 max_size_limit = 2 * (max_input_size + max_h + val_size )
543556 self .dataset = self ._prepare_fit_for_local_files (
544557 files_list = df ,
You can’t perform that action at this time.
0 commit comments