Skip to content

Commit ed1d4b4

Browse files
Charlotte TumescheitCharlotte Tumescheit
authored andcommitted
adjust to current dev branch
1 parent af7df07 commit ed1d4b4

File tree

5 files changed

+105
-26
lines changed

5 files changed

+105
-26
lines changed

chebai/cli.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,18 @@ def call_data_methods(data: Type[XYBaseDataModule]):
4646
data.setup()
4747
return data.num_of_labels
4848

49-
# parser.link_arguments(
50-
# "data",
51-
# "model.init_args.out_dim",
52-
# apply_on="instantiate",
53-
# compute_fn=call_data_methods,
54-
# )
49+
parser.link_arguments(
50+
"data",
51+
"model.init_args.out_dim",
52+
apply_on="instantiate",
53+
compute_fn=call_data_methods,
54+
)
5555

56-
# parser.link_arguments(
57-
# "data.feature_vector_size",
58-
# "model.init_args.input_dim",
59-
# apply_on="instantiate",
60-
# )
56+
parser.link_arguments(
57+
"data.feature_vector_size",
58+
"model.init_args.input_dim",
59+
apply_on="instantiate",
60+
)
6161

6262
for kind in ("train", "val", "test"):
6363
for average in ("micro-f1", "macro-f1", "balanced-accuracy", "f1", "mse", "rmse","r2"):
@@ -66,10 +66,14 @@ def call_data_methods(data: Type[XYBaseDataModule]):
6666
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
6767
apply_on="instantiate",
6868
)
69+
6970
parser.link_arguments(
70-
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
71+
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
7172
)
7273
# parser.link_arguments(
74+
# "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
75+
# )
76+
# parser.link_arguments(
7377
# "data", "model.init_args.criterion.init_args.data_extractor"
7478
# )
7579
# parser.link_arguments(

chebai/models/electra.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
2121

22-
from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa
22+
# TODO: put back in before pull request
23+
# from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa
2324

2425

2526
class ElectraPre(ChebaiBaseNet):
@@ -40,6 +41,7 @@ class ElectraPre(ChebaiBaseNet):
4041

4142
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
4243
super().__init__(config=config, **kwargs)
44+
4345
self.generator_config = ElectraConfig(**config["generator"])
4446
self.generator = ElectraForMaskedLM(self.generator_config)
4547
self.discriminator_config = ElectraConfig(**config["discriminator"])

chebai/preprocessing/datasets/tox21.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _load_dict(self, input_file_path: str) -> List[Dict]:
163163
features=smiles, labels=labels, ident=row["mol_id"], group=group
164164
)
165165
# yield self.reader.to_data(dict(features=smiles, labels=labels, ident=row["mol_id"]))
166-
def _set_processed_data_props(self):
166+
def _set_processed_data_props(self):
167167
"""
168168
Load processed data and extract metadata.
169169

chebai/result/molplot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from networkx.algorithms.isomorphism import GraphMatcher
1212
from pysmiles.read_smiles import LOGGER, TokenType, _tokenize
1313
from rdkit import Chem
14-
from rdkit.Chem.Draw import MolToMPL, rdMolDraw2D
14+
from rdkit.Chem.Draw import rdMolDraw2D
15+
# from rdkit.Chem.Draw import MolToMPL, rdMolDraw2D
1516

16-
from chebai.preprocessing.datasets import JCI_500_COLUMNS_INT
17+
# from chebai.preprocessing.datasets import JCI_500_COLUMNS_INT
1718
from chebai.result.base import ResultProcessor
1819

1920

tutorials/demo_process_results.ipynb

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
"end_time": "2023-11-29T08:17:25.832642900Z",
99
"start_time": "2023-11-29T08:17:25.816890700Z"
1010
},
11-
"collapsed": true
11+
"collapsed": true,
12+
"jupyter": {
13+
"outputs_hidden": true
14+
}
1215
},
1316
"outputs": [],
1417
"source": [
@@ -37,7 +40,10 @@
3740
"end_time": "2023-11-24T09:13:26.387885900Z",
3841
"start_time": "2023-11-24T09:06:23.191727Z"
3942
},
40-
"collapsed": false
43+
"collapsed": false,
44+
"jupyter": {
45+
"outputs_hidden": false
46+
}
4147
},
4248
"outputs": [
4349
{
@@ -109,7 +115,10 @@
109115
"end_time": "2023-11-29T08:33:48.374202Z",
110116
"start_time": "2023-11-29T08:33:48.261436600Z"
111117
},
112-
"collapsed": false
118+
"collapsed": false,
119+
"jupyter": {
120+
"outputs_hidden": false
121+
}
113122
},
114123
"outputs": [
115124
{
@@ -239,7 +248,10 @@
239248
"end_time": "2023-11-24T09:55:24.187152800Z",
240249
"start_time": "2023-11-24T09:55:21.580572700Z"
241250
},
242-
"collapsed": false
251+
"collapsed": false,
252+
"jupyter": {
253+
"outputs_hidden": false
254+
}
243255
},
244256
"outputs": [
245257
{
@@ -275,6 +287,9 @@
275287
"execution_count": 2,
276288
"metadata": {
277289
"collapsed": false,
290+
"jupyter": {
291+
"outputs_hidden": false
292+
},
278293
"pycharm": {
279294
"name": "#%%\n"
280295
}
@@ -299,6 +314,9 @@
299314
"execution_count": 4,
300315
"metadata": {
301316
"collapsed": false,
317+
"jupyter": {
318+
"outputs_hidden": false
319+
},
302320
"pycharm": {
303321
"name": "#%%\n"
304322
}
@@ -338,6 +356,9 @@
338356
"execution_count": 7,
339357
"metadata": {
340358
"collapsed": false,
359+
"jupyter": {
360+
"outputs_hidden": false
361+
},
341362
"pycharm": {
342363
"name": "#%%\n"
343364
}
@@ -360,6 +381,9 @@
360381
"execution_count": 3,
361382
"metadata": {
362383
"collapsed": false,
384+
"jupyter": {
385+
"outputs_hidden": false
386+
},
363387
"pycharm": {
364388
"name": "#%%\n"
365389
}
@@ -382,6 +406,9 @@
382406
"execution_count": 4,
383407
"metadata": {
384408
"collapsed": false,
409+
"jupyter": {
410+
"outputs_hidden": false
411+
},
385412
"pycharm": {
386413
"name": "#%%\n"
387414
}
@@ -403,6 +430,9 @@
403430
"execution_count": 9,
404431
"metadata": {
405432
"collapsed": false,
433+
"jupyter": {
434+
"outputs_hidden": false
435+
},
406436
"pycharm": {
407437
"name": "#%%\n"
408438
}
@@ -428,6 +458,9 @@
428458
"execution_count": 11,
429459
"metadata": {
430460
"collapsed": false,
461+
"jupyter": {
462+
"outputs_hidden": false
463+
},
431464
"pycharm": {
432465
"name": "#%%\n"
433466
}
@@ -451,6 +484,9 @@
451484
"execution_count": 5,
452485
"metadata": {
453486
"collapsed": false,
487+
"jupyter": {
488+
"outputs_hidden": false
489+
},
454490
"pycharm": {
455491
"name": "#%%\n"
456492
}
@@ -483,6 +519,9 @@
483519
"execution_count": 58,
484520
"metadata": {
485521
"collapsed": false,
522+
"jupyter": {
523+
"outputs_hidden": false
524+
},
486525
"pycharm": {
487526
"name": "#%%\n"
488527
}
@@ -643,6 +682,9 @@
643682
"execution_count": 12,
644683
"metadata": {
645684
"collapsed": false,
685+
"jupyter": {
686+
"outputs_hidden": false
687+
},
646688
"pycharm": {
647689
"name": "#%%\n"
648690
}
@@ -700,6 +742,9 @@
700742
"execution_count": 11,
701743
"metadata": {
702744
"collapsed": false,
745+
"jupyter": {
746+
"outputs_hidden": false
747+
},
703748
"pycharm": {
704749
"name": "#%%\n"
705750
}
@@ -730,7 +775,10 @@
730775
{
731776
"cell_type": "markdown",
732777
"metadata": {
733-
"collapsed": false
778+
"collapsed": false,
779+
"jupyter": {
780+
"outputs_hidden": false
781+
}
734782
},
735783
"source": [
736784
"Results:\n",
@@ -762,6 +810,9 @@
762810
"execution_count": 40,
763811
"metadata": {
764812
"collapsed": false,
813+
"jupyter": {
814+
"outputs_hidden": false
815+
},
765816
"pycharm": {
766817
"name": "#%%\n"
767818
}
@@ -794,6 +845,9 @@
794845
"execution_count": 41,
795846
"metadata": {
796847
"collapsed": false,
848+
"jupyter": {
849+
"outputs_hidden": false
850+
},
797851
"pycharm": {
798852
"name": "#%%\n"
799853
}
@@ -826,6 +880,9 @@
826880
"execution_count": 42,
827881
"metadata": {
828882
"collapsed": false,
883+
"jupyter": {
884+
"outputs_hidden": false
885+
},
829886
"pycharm": {
830887
"name": "#%%\n"
831888
}
@@ -858,6 +915,9 @@
858915
"execution_count": 13,
859916
"metadata": {
860917
"collapsed": false,
918+
"jupyter": {
919+
"outputs_hidden": false
920+
},
861921
"pycharm": {
862922
"name": "#%%\n"
863923
}
@@ -912,6 +972,9 @@
912972
"start_time": "2023-11-24T07:36:43.594504200Z"
913973
},
914974
"collapsed": false,
975+
"jupyter": {
976+
"outputs_hidden": false
977+
},
915978
"pycharm": {
916979
"name": "#%%\n"
917980
}
@@ -958,6 +1021,9 @@
9581021
"start_time": "2023-11-24T07:36:51.800819200Z"
9591022
},
9601023
"collapsed": false,
1024+
"jupyter": {
1025+
"outputs_hidden": false
1026+
},
9611027
"pycharm": {
9621028
"name": "#%%\n"
9631029
}
@@ -984,6 +1050,9 @@
9841050
"execution_count": null,
9851051
"metadata": {
9861052
"collapsed": false,
1053+
"jupyter": {
1054+
"outputs_hidden": false
1055+
},
9871056
"pycharm": {
9881057
"name": "#%%\n"
9891058
}
@@ -1010,6 +1079,9 @@
10101079
"execution_count": null,
10111080
"metadata": {
10121081
"collapsed": false,
1082+
"jupyter": {
1083+
"outputs_hidden": false
1084+
},
10131085
"pycharm": {
10141086
"name": "#%%\n"
10151087
}
@@ -1035,23 +1107,23 @@
10351107
],
10361108
"metadata": {
10371109
"kernelspec": {
1038-
"display_name": "Python 3",
1110+
"display_name": "Python 3 (ipykernel)",
10391111
"language": "python",
10401112
"name": "python3"
10411113
},
10421114
"language_info": {
10431115
"codemirror_mode": {
10441116
"name": "ipython",
1045-
"version": 2
1117+
"version": 3
10461118
},
10471119
"file_extension": ".py",
10481120
"mimetype": "text/x-python",
10491121
"name": "python",
10501122
"nbconvert_exporter": "python",
1051-
"pygments_lexer": "ipython2",
1052-
"version": "2.7.6"
1123+
"pygments_lexer": "ipython3",
1124+
"version": "3.12.11"
10531125
}
10541126
},
10551127
"nbformat": 4,
1056-
"nbformat_minor": 0
1128+
"nbformat_minor": 4
10571129
}

0 commit comments

Comments
 (0)