@@ -81,35 +81,35 @@ def mtl_backward(
 
     check_optional_positive_chunk_size(parallel_chunk_size)
 
-    losses = as_checked_ordered_set(losses, "losses")
-    features = as_checked_ordered_set(features, "features")
+    losses_ = as_checked_ordered_set(losses, "losses")
+    features_ = as_checked_ordered_set(features, "features")
 
     if shared_params is None:
-        shared_params = get_leaf_tensors(tensors=features, excluded=[])
+        shared_params_ = get_leaf_tensors(tensors=features_, excluded=[])
     else:
-        shared_params = OrderedSet(shared_params)
+        shared_params_ = OrderedSet(shared_params)
     if tasks_params is None:
-        tasks_params = [get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
+        tasks_params_ = [get_leaf_tensors(tensors=[loss], excluded=features_) for loss in losses_]
     else:
-        tasks_params = [OrderedSet(task_params) for task_params in tasks_params]
+        tasks_params_ = [OrderedSet(task_params) for task_params in tasks_params]
 
-    if len(features) == 0:
+    if len(features_) == 0:
         raise ValueError("`features` cannot be empty.")
 
-    _check_no_overlap(shared_params, tasks_params)
-    _check_losses_are_scalar(losses)
+    _check_no_overlap(shared_params_, tasks_params_)
+    _check_losses_are_scalar(losses_)
 
-    if len(losses) == 0:
+    if len(losses_) == 0:
         raise ValueError("`losses` cannot be empty")
-    if len(losses) != len(tasks_params):
+    if len(losses_) != len(tasks_params_):
         raise ValueError("`losses` and `tasks_params` should have the same size.")
 
     backward_transform = _create_transform(
-        losses=losses,
-        features=features,
+        losses=losses_,
+        features=features_,
         aggregator=aggregator,
-        tasks_params=tasks_params,
-        shared_params=shared_params,
+        tasks_params=tasks_params_,
+        shared_params=shared_params_,
         retain_graph=retain_graph,
         parallel_chunk_size=parallel_chunk_size,
     )