Skip to content

Commit 9175251

Browse files
committed
update tutorials
1 parent 1a1f0d8 commit 9175251

File tree

30 files changed

+1522
-5230
lines changed

30 files changed

+1522
-5230
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,9 @@ cython_debug/
139139

140140
# Lightning logs dir
141141
**lightning_logs
142+
143+
# Tutorial logs dir
144+
**tutorial_logs
145+
146+
# tmp dir
147+
**tmp*

pina/problem/zoo/supervised_problem.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ..abstract_problem import AbstractProblem
44
from ... import Condition
5-
from ... import Graph
5+
from ... import LabelTensor
66

77

88
class SupervisedProblem(AbstractProblem):
@@ -22,16 +22,20 @@ class SupervisedProblem(AbstractProblem):
2222

2323
conditions = {}
2424
output_variables = None
25+
input_variables = None
2526

26-
def __init__(self, input_, output_):
27+
def __init__(self, input_, output_, input_variables=None, output_variables=None):
2728
"""
2829
Initialize the SupervisedProblem class.
2930
3031
:param input_: Input data of the problem.
32+
:type input_: torch.Tensor | LabelTensor | Graph | Data
3133
:param output_: Output data of the problem.
32-
:type output_: torch.Tensor | Graph
34+
:type output_: torch.Tensor | LabelTensor | Graph | Data
3335
"""
34-
if isinstance(input_, Graph):
35-
input_ = input_.data
36+
# Set input and output variables
37+
self.input_variables = input_variables
38+
self.output_variables = output_variables
39+
# Set the condition
3640
self.conditions["data"] = Condition(input=input_, target=output_)
3741
super().__init__()

pina/trainer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,6 @@ def __init__(
127127

128128
# logging
129129
self.logging_kwargs = {
130-
"logger": bool(
131-
kwargs["logger"] is not None or kwargs["logger"] is True
132-
),
133130
"sync_dist": bool(
134131
len(self._accelerator_connector._parallel_devices) > 1
135132
),

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ dev = [
3939
]
4040
tutorials = [
4141
"smithers @ git+https://github.com/mathLab/smithers.git",
42+
"torchvision",
43+
"tensorboard",
44+
"scipy",
45+
"numpy",
4246
]
4347

4448
[project.urls]

tutorials/tutorial1/tutorial.ipynb

Lines changed: 151 additions & 115 deletions
Large diffs are not rendered by default.

tutorials/tutorial1/tutorial.py

Lines changed: 0 additions & 322 deletions
This file was deleted.

0 commit comments

Comments
 (0)