3939
4040# To setup the test pack, if not already, run `atomworks setup tests`
4141dataset = FileDataset .from_directory (
42- directory = "../../tests/data/ml/af2_distillation/cif" , name = "example_directory_dataset"
42+ directory = "../../tests/data/ml/af2_distillation/cif" ,
43+ name = "example_directory_dataset" ,
4344)
4445
4546########################################################################
5859# Understanding Dataset Requirements
5960# ----------------------------------
6061#
61- # At a high level, to train models with AtomWorks, we need typically need a Dataset that:
62+ # At a high level, to train models with AtomWorks, we typically need a Dataset that:
6263#
6364# (1) Takes as input an item index and returns the corresponding example information; typically includes:
6465# a. Path to a structural file saved on disk (`/path/to/dataset/my_dataset_0.cif`)
@@ -84,7 +85,9 @@ def simple_loading_fn(raw_data) -> dict:
8485
8586
8687dataset_with_loading_fn = FileDataset .from_directory (
87- directory = "../../tests/data/pdb" , name = "example_pdb_dataset" , loader = simple_loading_fn
88+ directory = "../../tests/data/pdb" ,
89+ name = "example_pdb_dataset" ,
90+ loader = simple_loading_fn ,
8891)
8992output = dataset_with_loading_fn [1 ]
9093print (f"Output AtomArray has { len (output ['atom_array' ])} atoms!" )
@@ -120,7 +123,10 @@ def simple_loading_fn(raw_data) -> dict:
120123# Just like with the loading function, we can also pass a composed `Transform` pipeline to our datasets.
121124
122125dataset_with_loading_fn_and_transforms = FileDataset .from_directory (
123- directory = "../../tests/data/pdb" , name = "example_pdb_dataset" , loader = simple_loading_fn , transform = pipe
126+ directory = "../../tests/data/pdb" ,
127+ name = "example_pdb_dataset" ,
128+ loader = simple_loading_fn ,
129+ transform = pipe ,
124130)
125131
126132########################################################################
@@ -223,7 +229,9 @@ def simple_loading_fn(raw_data) -> dict:
223229 data = interfaces_df ,
224230 name = "interfaces_dataset" ,
225231 # We use a pre-built loader that takes in a list of column names and returns a loader function
226- loader = create_loader_with_query_pn_units (pn_unit_iid_colnames = ["pn_unit_1_iid" , "pn_unit_2_iid" ]),
232+ loader = create_loader_with_query_pn_units (
233+ pn_unit_iid_colnames = ["pn_unit_1_iid" , "pn_unit_2_iid" ]
234+ ),
227235 transform = pipe ,
228236)
229237
0 commit comments