File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -44,8 +44,9 @@ phot = "PSF"
4444mode = " forced"
4545
4646[data_set .LSDBDataGenerator .dask ]
47- cluster_type = " none "
47+ cluster_type = " LocalCluster " # none or LocalCLuster
4848
49+ # kwargs to pass to `dask.distributed.CLient(**kwargs)`
4950[data_set .LSDBDataGenerator .dask .LocalCluster ]
5051n_workers = 8
5152memory_limit = " 8GB"
Original file line number Diff line number Diff line change @@ -133,9 +133,11 @@ class LinearModel(torch.nn.Module):
133133 err_scaler_lg_upper = 5.0
134134 err_scaler_lg_interval = err_scaler_lg_upper - err_scaler_lg_lower
135135
136- def __init__ (self , config , shape ) -> None :
136+ def __init__ (self , config , data_sample ) -> None :
137137 super ().__init__ ()
138- self .layers = torch .nn .Linear (shape , 1 )
138+ self .config = config
139+ input_d = len (data_sample ["data" ])
140+ self .layers = torch .nn .Linear (input_d , 1 )
139141 self .loss = import_by_name (config ["model" ]["loss_fn" ])
140142 self .n_src = config ["data_set" ]["LSDBDataGenerator" ]["n_src" ]
141143 if not isinstance (self .n_src , int ):
You can’t perform that action at this time.
0 commit comments