-
Notifications
You must be signed in to change notification settings - Fork 92
Graph condition #471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Graph condition #471
Conversation
|
@FilippoOlivo I rebased on 0.2, cuz the history was a bit unclean |
8696366 to
39b4252
Compare
3ad2c8a to
f8f966c
Compare
0115673 to
ca3de7c
Compare
|
@dario-coscia, I moved the extract logic of |
|
@FilippoOlivo ok perfect! I will review the PR. Can you first clean the github history? |
4a2bc6a to
9bfd926
Compare
dario-coscia
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still missing the tests to review fully, but there are some questions on the new graph structure dataset and condition first to understand. Besides that I think it is a very great job, let's simplify a bit and then it is ready to go!
| "InputOutputPointsCondition", | ||
| "GraphInputOutputCondition", | ||
| "GraphDataCondition", | ||
| "GraphInputEquationCondition", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this really useful? I think this is useful only if you want to differentiate wrt graph tensors (node features, or others). Is it correct?
| # The split between GraphInputOutputCondition and GraphDataCondition | ||
| # distinguishes different types of graph conditions passed to problems. | ||
| # This separation simplifies consistency checks during problem creation. | ||
| class GraphDataCondition(GraphCondition): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the difference between graphDatacondition and GraphCondition?
|
|
||
|
|
||
| class GraphInputOutputCondition(GraphCondition): | ||
| pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not implemented?
| if k != "equation" | ||
| # Equations are NEVER dataloaded | ||
| } | ||
| if offset + stage_len >= len_condition: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why now = is off? It was put in #465 for solving a bug
| return wrapper | ||
|
|
||
| @staticmethod | ||
| def _extract_cond_vars(func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand _extract_output and why it exists, but what is the purpose of _extract_cond_vars?
| :param torch.Tensor pos: The position tensor. | ||
| """ | ||
| if pos is not None: | ||
| check_consistency(pos, (torch.Tensor, LabelTensor)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why this change
| edge_index = to_undirected(edge_index) | ||
| return edge_index | ||
|
|
||
| def extract(self, labels): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe a better option would be def extract(self, labels, attr='x') where attr is where you want to extract the label from.
|
|
||
| def __new__(cls, *args, **kwargs): | ||
|
|
||
| if sorted(list(kwargs.keys())) == sorted(["input_", "output_"]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would start writing input, target. And for graph make another problem. supervised problems are defined on an input and on a target, having only one value is misleading
| y = y[0] | ||
| edge_index = graph[0].edge_index | ||
| graph = Data(x=x, pos=pos, edge_index=edge_index, y=y) | ||
| condition = Condition(graph=graph) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like this, is very complicated to understand from a user perspective. If there is an input and a target we must specify that in the condition, otherwise the problem is unsupervised and you just use the data condition
9bfd926 to
78e6482
Compare
78e6482 to
91bd0c6
Compare
GiovanniCanali
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for implementing the graph condition @FilippoOlivo!
I just noticed a print statement. Remove it if not needed.
| def __new__(cls, conditions_dict, **kwargs): | ||
| if len(conditions_dict) == 0: | ||
| raise ValueError("No conditions provided") | ||
| print(conditions_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I came across this print. Is it intentional?
|
closing, moving to #475 |
No description provided.