Fix Issue 401 Error #448
Open: ooples wants to merge 12 commits into master from claude/fix-issue-401-011CUw2PdjV8WHcpjT2wAkpr (+9,745 −7)
Commits (12):
705c5d6 (claude) Expand Graph Neural Networks beyond GCN (Issue #401)
18de217 (ooples) fix: resolve 22 build errors in graph neural network layers
e405284 (ooples) refactor: address 5 pr review comments with production-ready code
23e7580 (ooples) fix: implement proper gradient computation for graphattentionlayer tr…
541725b (ooples) fix: exclude graphconvolutionallayer from lora wrapping to prevent ap…
1f3bbd6 (ooples) fix: apply dropout to attention coefficients and clear cached transfo…
4829b48 (ooples) fix: only apply dropout during training mode to prevent inference deg…
b9f9fdd (ooples) fix: resolve merge conflict with master - add both igraphconvolutionl…
ded3d8f (ooples) refactor: extract sageaggregatortype enum to its own file
ff2bf25 (claude) Add Message Passing Framework and Advanced GNN Architectures (Issue #…
93b15fa (claude) Complete Advanced GNN Architectures (Issue #401 - Phase 1 Complete)
29b8ca2 (claude) Add Graph Learning Task Frameworks and Dataset Support (Issue #401 - …
@@ -0,0 +1,132 @@

```csharp
using AiDotNet.LinearAlgebra;

namespace AiDotNet.Data.Abstractions;

/// <summary>
/// Represents a graph classification task where the goal is to classify entire graphs.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
/// <remarks>
/// <para>
/// Graph classification assigns a label to an entire graph based on its structure and node/edge features.
/// Unlike node classification (classify individual nodes) or link prediction (predict edges),
/// graph classification treats the whole graph as a single data point.
/// </para>
/// <para><b>For Beginners:</b> Graph classification is like determining the category of a complex object.
///
/// **Real-world examples:**
///
/// **Molecular Property Prediction:**
/// - Input: Molecular graph (atoms as nodes, bonds as edges)
/// - Task: Predict molecular properties
/// - Examples:
/// * Is this molecule toxic?
/// * What is the solubility?
/// * Will this be a good drug candidate?
/// - Dataset: ZINC, QM9, BACE
///
/// **Protein Function Prediction:**
/// - Input: Protein structure graph
/// - Task: Predict protein function or family
/// - How: Analyze amino acid sequences and 3D structure
///
/// **Chemical Reaction Prediction:**
/// - Input: Reaction graph showing reactants and products
/// - Task: Predict reaction type or outcome
///
/// **Social Network Analysis:**
/// - Input: Community subgraphs
/// - Task: Classify community type or behavior
/// - Example: Identify bot networks vs organic communities
///
/// **Code Analysis:**
/// - Input: Abstract syntax tree (AST) or control flow graph
/// - Task: Detect bugs, classify code functionality
/// - Example: "Is this code snippet vulnerable to SQL injection?"
///
/// **Key Challenge:** Graph-level representation
/// - Must aggregate information from all nodes and edges
/// - Common approaches: Global pooling, hierarchical pooling, set2set
/// </para>
/// </remarks>
public class GraphClassificationTask<T>
{
    /// <summary>
    /// List of training graphs.
    /// </summary>
    /// <remarks>
    /// Each graph in the list is an independent sample with its own structure and features.
    /// </remarks>
    public List<GraphData<T>> TrainGraphs { get; set; } = new List<GraphData<T>>();

    /// <summary>
    /// List of validation graphs.
    /// </summary>
    public List<GraphData<T>> ValGraphs { get; set; } = new List<GraphData<T>>();

    /// <summary>
    /// List of test graphs.
    /// </summary>
    public List<GraphData<T>> TestGraphs { get; set; } = new List<GraphData<T>>();

    /// <summary>
    /// Labels for training graphs.
    /// Shape: [num_train_graphs] or [num_train_graphs, num_classes] for multi-label.
    /// </summary>
    public Tensor<T> TrainLabels { get; set; } = new Tensor<T>([0]);

    /// <summary>
    /// Labels for validation graphs.
    /// Shape: [num_val_graphs] or [num_val_graphs, num_classes].
    /// </summary>
    public Tensor<T> ValLabels { get; set; } = new Tensor<T>([0]);

    /// <summary>
    /// Labels for test graphs.
    /// Shape: [num_test_graphs] or [num_test_graphs, num_classes].
    /// </summary>
    public Tensor<T> TestLabels { get; set; } = new Tensor<T>([0]);

    /// <summary>
    /// Number of classes in the classification task.
    /// </summary>
    public int NumClasses { get; set; }

    /// <summary>
    /// Whether this is a multi-label classification task.
    /// </summary>
    /// <remarks>
    /// - False: Each graph has exactly one label (e.g., molecule is toxic or not)
    /// - True: Each graph can have multiple labels (e.g., molecule has multiple properties)
    /// </remarks>
    public bool IsMultiLabel { get; set; } = false;

    /// <summary>
    /// Whether this is a regression task instead of classification.
    /// </summary>
    /// <remarks>
    /// <para>
    /// For regression tasks (e.g., predicting molecular energy), labels are continuous values
    /// rather than discrete classes.
    /// </para>
    /// <para><b>For Beginners:</b> The difference between classification and regression:
    /// - **Classification**: Predict categories (e.g., "toxic" vs "non-toxic")
    /// - **Regression**: Predict continuous values (e.g., "solubility = 2.3 mg/L")
    ///
    /// Examples:
    /// - Classification: Is this molecule a good drug? (Yes/No)
    /// - Regression: What is this molecule's binding affinity? (0.0 to 10.0)
    /// </para>
    /// </remarks>
    public bool IsRegression { get; set; } = false;

    /// <summary>
    /// Average number of nodes per graph (for informational purposes).
    /// </summary>
    public double AvgNumNodes { get; set; }

    /// <summary>
    /// Average number of edges per graph (for informational purposes).
    /// </summary>
    public double AvgNumEdges { get; set; }
}
```
@@ -0,0 +1,124 @@

```csharp
using AiDotNet.LinearAlgebra;

namespace AiDotNet.Data.Abstractions;

/// <summary>
/// Represents a single graph with nodes, edges, features, and optional labels.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
/// <remarks>
/// <para>
/// GraphData encapsulates all information about a graph structure including:
/// - Node features (attributes for each node)
/// - Edge indices (connections between nodes)
/// - Edge features (optional attributes for edges)
/// - Adjacency matrix (graph structure in matrix form)
/// - Labels (for supervised learning tasks)
/// </para>
/// <para><b>For Beginners:</b> Think of a graph as a social network:
/// - **Nodes**: People in the network
/// - **Edges**: Friendships or connections between people
/// - **Node Features**: Each person's attributes (age, interests, etc.)
/// - **Edge Features**: Relationship attributes (how long they've been friends, interaction frequency)
/// - **Labels**: What we want to predict (e.g., will this person like a product?)
///
/// This class packages all this information together for graph neural network training.
/// </para>
/// </remarks>
public class GraphData<T>
{
    /// <summary>
    /// Node feature matrix of shape [num_nodes, num_features].
    /// </summary>
    /// <remarks>
    /// Each row represents one node's feature vector. For example, in a molecular graph,
    /// features might include atom type, charge, hybridization, etc.
    /// </remarks>
    public Tensor<T> NodeFeatures { get; set; } = new Tensor<T>([0, 0]);

    /// <summary>
    /// Edge index tensor of shape [2, num_edges] or [num_edges, 2].
    /// Format: [source_nodes; target_nodes] or [[src, tgt], [src, tgt], ...].
    /// </summary>
    /// <remarks>
    /// <para>
    /// Stores graph connectivity in COO (Coordinate) format. Each edge is represented by
    /// a (source, target) pair of node indices.
    /// </para>
    /// <para><b>For Beginners:</b> If node 0 connects to node 1, and node 1 connects to node 2:
    /// EdgeIndex = [[0, 1], [1, 2]] or transposed as [[0, 1], [1, 2]]
    /// This is a compact way to store which nodes are connected.
    /// </para>
    /// </remarks>
    public Tensor<T> EdgeIndex { get; set; } = new Tensor<T>([0, 2]);

    /// <summary>
    /// Optional edge feature matrix of shape [num_edges, num_edge_features].
    /// </summary>
    /// <remarks>
    /// Each row contains features for one edge. In molecular graphs, this could be
    /// bond type, bond length, stereochemistry, etc.
    /// </remarks>
    public Tensor<T>? EdgeFeatures { get; set; }

    /// <summary>
    /// Adjacency matrix of shape [num_nodes, num_nodes] or [batch_size, num_nodes, num_nodes].
    /// </summary>
    /// <remarks>
    /// Square matrix where A[i,j] = 1 if edge exists from node i to j, 0 otherwise.
    /// Can be weighted for graphs with edge weights.
    /// </remarks>
    public Tensor<T>? AdjacencyMatrix { get; set; }

    /// <summary>
    /// Node labels for node-level tasks (e.g., node classification).
    /// Shape: [num_nodes] or [num_nodes, num_classes].
    /// </summary>
    public Tensor<T>? NodeLabels { get; set; }

    /// <summary>
    /// Graph-level label for graph-level tasks (e.g., graph classification).
    /// Shape: [1] or [num_classes].
    /// </summary>
    public Tensor<T>? GraphLabel { get; set; }

    /// <summary>
    /// Mask indicating which nodes are in the training set.
    /// </summary>
    public Tensor<T>? TrainMask { get; set; }

    /// <summary>
    /// Mask indicating which nodes are in the validation set.
    /// </summary>
    public Tensor<T>? ValMask { get; set; }

    /// <summary>
    /// Mask indicating which nodes are in the test set.
    /// </summary>
    public Tensor<T>? TestMask { get; set; }

    /// <summary>
    /// Number of nodes in the graph.
    /// </summary>
    public int NumNodes => NodeFeatures.Shape[0];

    /// <summary>
    /// Number of edges in the graph.
    /// </summary>
    public int NumEdges => EdgeIndex.Shape[0];

    /// <summary>
    /// Number of node features.
    /// </summary>
    public int NumNodeFeatures => NodeFeatures.Shape.Length > 1 ? NodeFeatures.Shape[1] : 0;

    /// <summary>
    /// Number of edge features (0 if no edge features).
    /// </summary>
    public int NumEdgeFeatures => EdgeFeatures?.Shape[1] ?? 0;

    /// <summary>
    /// Metadata for heterogeneous graphs (optional).
    /// </summary>
    public Dictionary<string, object>? Metadata { get; set; }
}
```
`NumEdges` conflicts with documented `EdgeIndex` shapes

XML docs state `EdgeIndex` may be `[2, num_edges]` or `[num_edges, 2]`, but `NumEdges => EdgeIndex.Shape[0]` only works correctly for the `[num_edges, 2]` convention. For `[2, num_edges]` it would always return 2.

To support both documented layouts, consider:
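One possible shape for such a fix is sketched below as a standalone helper. `EdgeIndexShape.CountEdges` is a hypothetical name, not part of AiDotNet and not necessarily what the reviewer proposed, and the sketch assumes `Tensor<T>.Shape` is exposed as an `int[]`. It treats whichever dimension equals 2 as the (source, target) axis and returns the other one.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class EdgeIndexShape
{
    // Infers the edge count from a 2-D edge-index shape that follows either
    // documented layout: [num_edges, 2] or [2, num_edges].
    public static int CountEdges(int[] shape)
    {
        if (shape is null || shape.Length != 2)
            throw new ArgumentException("EdgeIndex must be a 2-D tensor.", nameof(shape));

        // [num_edges, 2] layout; this also covers the [0, 2] default in GraphData<T>.
        if (shape[1] == 2)
            return shape[0];

        // [2, num_edges] layout.
        if (shape[0] == 2)
            return shape[1];

        throw new ArgumentException(
            "EdgeIndex shape must be [num_edges, 2] or [2, num_edges].", nameof(shape));
    }
}
```

Under that assumption, the property in `GraphData<T>` could read `public int NumEdges => EdgeIndexShape.CountEdges(EdgeIndex.Shape);`. A `[2, 2]` shape yields 2 under either layout, so the count is unambiguous there, although the orientation of the entries still is not.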
Also applies to: 105-108