Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/Data/Abstractions/GraphClassificationTask.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using AiDotNet.LinearAlgebra;

namespace AiDotNet.Data.Abstractions;

/// <summary>
/// Represents a graph classification task where the goal is to classify entire graphs.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
/// <remarks>
/// <para>
/// Graph classification assigns a label to an entire graph based on its structure and node/edge features.
/// Unlike node classification (classify individual nodes) or link prediction (predict edges),
/// graph classification treats the whole graph as a single data point.
/// </para>
/// <para><b>For Beginners:</b> Graph classification is like determining the category of a complex object.
///
/// **Real-world examples:**
///
/// **Molecular Property Prediction:**
/// - Input: Molecular graph (atoms as nodes, bonds as edges)
/// - Task: Predict molecular properties
/// - Examples:
/// * Is this molecule toxic?
/// * What is the solubility?
/// * Will this be a good drug candidate?
/// - Dataset: ZINC, QM9, BACE
///
/// **Protein Function Prediction:**
/// - Input: Protein structure graph
/// - Task: Predict protein function or family
/// - How: Analyze amino acid sequences and 3D structure
///
/// **Chemical Reaction Prediction:**
/// - Input: Reaction graph showing reactants and products
/// - Task: Predict reaction type or outcome
///
/// **Social Network Analysis:**
/// - Input: Community subgraphs
/// - Task: Classify community type or behavior
/// - Example: Identify bot networks vs organic communities
///
/// **Code Analysis:**
/// - Input: Abstract syntax tree (AST) or control flow graph
/// - Task: Detect bugs, classify code functionality
/// - Example: "Is this code snippet vulnerable to SQL injection?"
///
/// **Key Challenge:** Graph-level representation
/// - Must aggregate information from all nodes and edges
/// - Common approaches: Global pooling, hierarchical pooling, set2set
/// </para>
/// </remarks>
public class GraphClassificationTask<T>
{
/// <summary>
/// List of training graphs.
/// </summary>
/// <remarks>
/// Each graph in the list is an independent sample with its own structure and features.
/// </remarks>
public List<GraphData<T>> TrainGraphs { get; set; } = new List<GraphData<T>>();

/// <summary>
/// List of validation graphs.
/// </summary>
public List<GraphData<T>> ValGraphs { get; set; } = new List<GraphData<T>>();

/// <summary>
/// List of test graphs.
/// </summary>
public List<GraphData<T>> TestGraphs { get; set; } = new List<GraphData<T>>();

/// <summary>
/// Labels for training graphs.
/// Shape: [num_train_graphs] or [num_train_graphs, num_classes] for multi-label.
/// </summary>
public Tensor<T> TrainLabels { get; set; } = new Tensor<T>([0]);

/// <summary>
/// Labels for validation graphs.
/// Shape: [num_val_graphs] or [num_val_graphs, num_classes].
/// </summary>
public Tensor<T> ValLabels { get; set; } = new Tensor<T>([0]);

/// <summary>
/// Labels for test graphs.
/// Shape: [num_test_graphs] or [num_test_graphs, num_classes].
/// </summary>
public Tensor<T> TestLabels { get; set; } = new Tensor<T>([0]);

/// <summary>
/// Number of classes in the classification task.
/// </summary>
public int NumClasses { get; set; }

/// <summary>
/// Whether this is a multi-label classification task.
/// </summary>
/// <remarks>
/// - False: Each graph has exactly one label (e.g., molecule is toxic or not)
/// - True: Each graph can have multiple labels (e.g., molecule has multiple properties)
/// </remarks>
public bool IsMultiLabel { get; set; } = false;

/// <summary>
/// Whether this is a regression task instead of classification.
/// </summary>
/// <remarks>
/// <para>
/// For regression tasks (e.g., predicting molecular energy), labels are continuous values
/// rather than discrete classes.
/// </para>
/// <para><b>For Beginners:</b> The difference between classification and regression:
/// - **Classification**: Predict categories (e.g., "toxic" vs "non-toxic")
/// - **Regression**: Predict continuous values (e.g., "solubility = 2.3 mg/L")
///
/// Examples:
/// - Classification: Is this molecule a good drug? (Yes/No)
/// - Regression: What is this molecule's binding affinity? (0.0 to 10.0)
/// </para>
/// </remarks>
public bool IsRegression { get; set; } = false;

/// <summary>
/// Average number of nodes per graph (for informational purposes).
/// </summary>
public double AvgNumNodes { get; set; }

/// <summary>
/// Average number of edges per graph (for informational purposes).
/// </summary>
public double AvgNumEdges { get; set; }
}
124 changes: 124 additions & 0 deletions src/Data/Abstractions/GraphData.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using AiDotNet.LinearAlgebra;

namespace AiDotNet.Data.Abstractions;

/// <summary>
/// Represents a single graph with nodes, edges, features, and optional labels.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
/// <remarks>
/// <para>
/// GraphData encapsulates all information about a graph structure including:
/// - Node features (attributes for each node)
/// - Edge indices (connections between nodes)
/// - Edge features (optional attributes for edges)
/// - Adjacency matrix (graph structure in matrix form)
/// - Labels (for supervised learning tasks)
/// </para>
/// <para><b>For Beginners:</b> Think of a graph as a social network:
/// - **Nodes**: People in the network
/// - **Edges**: Friendships or connections between people
/// - **Node Features**: Each person's attributes (age, interests, etc.)
/// - **Edge Features**: Relationship attributes (how long they've been friends, interaction frequency)
/// - **Labels**: What we want to predict (e.g., will this person like a product?)
///
/// This class packages all this information together for graph neural network training.
/// </para>
/// </remarks>
public class GraphData<T>
{
/// <summary>
/// Node feature matrix of shape [num_nodes, num_features].
/// </summary>
/// <remarks>
/// Each row represents one node's feature vector. For example, in a molecular graph,
/// features might include atom type, charge, hybridization, etc.
/// </remarks>
public Tensor<T> NodeFeatures { get; set; } = new Tensor<T>([0, 0]);

/// <summary>
/// Edge index tensor of shape [2, num_edges] or [num_edges, 2].
/// Format: [source_nodes; target_nodes] or [[src, tgt], [src, tgt], ...].
/// </summary>
/// <remarks>
/// <para>
/// Stores graph connectivity in COO (Coordinate) format. Each edge is represented by
/// a (source, target) pair of node indices.
/// </para>
/// <para><b>For Beginners:</b> If node 0 connects to node 1, and node 1 connects to node 2:
/// EdgeIndex = [[0, 1], [1, 2]] or transposed as [[0, 1], [1, 2]]
/// This is a compact way to store which nodes are connected.
/// </para>
/// </remarks>
public Tensor<T> EdgeIndex { get; set; } = new Tensor<T>([0, 2]);
Comment on lines +39 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

NumEdges conflicts with documented EdgeIndex shapes

XML docs state EdgeIndex may be [2, num_edges] or [num_edges, 2], but

public int NumEdges => EdgeIndex.Shape[0];

only works correctly for the [num_edges, 2] convention. For [2, num_edges] it would always return 2.

To support both documented layouts, consider:

-    /// <summary>
-    /// Number of edges in the graph.
-    /// </summary>
-    public int NumEdges => EdgeIndex.Shape[0];
+    /// <summary>
+    /// Number of edges in the graph.
+    /// Supports both [2, num_edges] and [num_edges, 2] layouts.
+    /// </summary>
+    public int NumEdges
+    {
+        get
+        {
+            if (EdgeIndex.Shape.Length < 2)
+            {
+                return 0;
+            }
+
+            // [2, num_edges] => use second dimension; otherwise assume [num_edges, 2].
+            return EdgeIndex.Shape[0] == 2
+                ? EdgeIndex.Shape[1]
+                : EdgeIndex.Shape[0];
+        }
+    }

Also applies to: 105-108

🤖 Prompt for AI Agents
In src/Data/Abstractions/GraphData.cs around lines 39-53 (and also update the
similar logic at lines 105-108), the NumEdges getter assumes EdgeIndex is laid
out as [num_edges, 2] by returning EdgeIndex.Shape[0], which is incorrect for
the documented alternative layout [2, num_edges]. Change NumEdges to detect the
layout: if EdgeIndex has rank 2 and one dimension equals 2 then return the other
dimension; if neither dimension equals 2 throw or assert a clear
InvalidOperationException with guidance; ensure the same corrected logic is
applied to the code at lines 105-108 and add a brief validation check for null
or unexpected ranks before accessing Shape.


/// <summary>
/// Optional edge feature matrix of shape [num_edges, num_edge_features].
/// </summary>
/// <remarks>
/// Each row contains features for one edge. In molecular graphs, this could be
/// bond type, bond length, stereochemistry, etc.
/// </remarks>
public Tensor<T>? EdgeFeatures { get; set; }

/// <summary>
/// Adjacency matrix of shape [num_nodes, num_nodes] or [batch_size, num_nodes, num_nodes].
/// </summary>
/// <remarks>
/// Square matrix where A[i,j] = 1 if edge exists from node i to j, 0 otherwise.
/// Can be weighted for graphs with edge weights.
/// </remarks>
public Tensor<T>? AdjacencyMatrix { get; set; }

/// <summary>
/// Node labels for node-level tasks (e.g., node classification).
/// Shape: [num_nodes] or [num_nodes, num_classes].
/// </summary>
public Tensor<T>? NodeLabels { get; set; }

/// <summary>
/// Graph-level label for graph-level tasks (e.g., graph classification).
/// Shape: [1] or [num_classes].
/// </summary>
public Tensor<T>? GraphLabel { get; set; }

/// <summary>
/// Mask indicating which nodes are in the training set.
/// </summary>
public Tensor<T>? TrainMask { get; set; }

/// <summary>
/// Mask indicating which nodes are in the validation set.
/// </summary>
public Tensor<T>? ValMask { get; set; }

/// <summary>
/// Mask indicating which nodes are in the test set.
/// </summary>
public Tensor<T>? TestMask { get; set; }

/// <summary>
/// Number of nodes in the graph.
/// </summary>
public int NumNodes => NodeFeatures.Shape[0];

/// <summary>
/// Number of edges in the graph.
/// </summary>
public int NumEdges => EdgeIndex.Shape[0];

/// <summary>
/// Number of node features.
/// </summary>
public int NumNodeFeatures => NodeFeatures.Shape.Length > 1 ? NodeFeatures.Shape[1] : 0;

/// <summary>
/// Number of edge features (0 if no edge features).
/// </summary>
public int NumEdgeFeatures => EdgeFeatures?.Shape[1] ?? 0;

/// <summary>
/// Metadata for heterogeneous graphs (optional).
/// </summary>
public Dictionary<string, object>? Metadata { get; set; }
}
Loading