Skip to content

Commit a34e212

Browse files
committed
Merge branch 'claude/add-nested-learning-011CV1ZosqMVM7gWV67ayo6p' of http://127.0.0.1:54633/git/ooples/AiDotNet into claude/add-nested-learning-011CV1ZosqMVM7gWV67ayo6p
2 parents 25888a0 + bcbcf79 commit a34e212

File tree

6 files changed

+1933
-0
lines changed

6 files changed

+1933
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using AiDotNet.LinearAlgebra;
2+
3+
namespace AiDotNet.Interfaces;
4+
5+
/// <summary>
/// Interface for self-supervised loss functions used in meta-learning.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations (e.g., float, double).</typeparam>
/// <remarks>
/// <para>
/// Self-supervised learning creates artificial tasks from unlabeled data, allowing models
/// to learn useful representations without explicit labels. This is particularly valuable
/// in meta-learning where the query set is often large but unlabeled.
/// </para>
/// <para><b>For Beginners:</b> Self-supervised learning is like learning by creating your own practice problems.
///
/// Example: Rotation prediction for images
/// - Take an unlabeled image
/// - Rotate it by 0°, 90°, 180°, or 270°
/// - Train the model to predict which rotation was applied
/// - The model learns spatial relationships and features without needing class labels
///
/// This is powerful because:
/// 1. You can use unlabeled data (which is often abundant)
/// 2. The model learns useful features automatically
/// 3. These features help with the actual task (classification, etc.)
///
/// Think of it like learning to recognize faces by first learning to identify if a photo is upside down.
/// You don't need to know who the person is to learn about facial features!
/// </para>
/// <para><b>Common Self-Supervised Tasks:</b>
/// - <b>Rotation Prediction:</b> Predict rotation angle (0°, 90°, 180°, 270°)
/// - <b>Jigsaw Puzzles:</b> Solve scrambled image patches
/// - <b>Colorization:</b> Predict color from grayscale
/// - <b>Context Prediction:</b> Predict spatial relationships between patches
/// - <b>Contrastive Learning:</b> Learn to distinguish similar vs dissimilar examples
/// </para>
/// </remarks>
public interface ISelfSupervisedLoss<T>
{
    /// <summary>
    /// Creates a self-supervised task from unlabeled input data.
    /// </summary>
    /// <typeparam name="TInput">The type of the unlabeled input data (e.g., Tensor&lt;T&gt; — concrete implementations may restrict this).</typeparam>
    /// <typeparam name="TOutput">The type of the generated labels (e.g., Tensor&lt;T&gt; — concrete implementations may restrict this).</typeparam>
    /// <param name="input">Unlabeled input data (e.g., images).</param>
    /// <returns>
    /// A tuple containing:
    /// - augmentedX: Transformed input data for the self-supervised task
    /// - augmentedY: Labels for the self-supervised task (e.g., rotation angles)
    /// </returns>
    /// <remarks>
    /// <para>
    /// This method transforms unlabeled data into a supervised learning problem
    /// by creating artificial labels based on the transformation applied.
    /// </para>
    /// <para><b>For Beginners:</b> This converts "unlabeled data" into a "labeled learning problem".
    ///
    /// Example for rotation prediction:
    /// - Input: 10 unlabeled images
    /// - Output: 40 labeled images (each original rotated 4 times: 0°, 90°, 180°, 270°)
    /// - Labels: [0, 1, 2, 3] indicating which rotation was applied
    ///
    /// The model learns to recognize rotations, which teaches it about:
    /// - Edge orientations
    /// - Spatial relationships
    /// - Object structure
    ///
    /// These learned features are useful for the actual classification task!
    /// </para>
    /// </remarks>
    (TInput augmentedX, TOutput augmentedY) CreateTask<TInput, TOutput>(TInput input);
}
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
using AiDotNet.Helpers;
2+
using AiDotNet.Interfaces;
3+
using AiDotNet.LinearAlgebra;
4+
5+
namespace AiDotNet.LossFunctions;
6+
7+
/// <summary>
/// Self-supervised loss function based on rotation prediction for images.
/// </summary>
/// <typeparam name="T">The numeric data type (e.g., float, double).</typeparam>
/// <remarks>
/// <para>
/// Rotation prediction is a self-supervised task where:
/// 1. Images are rotated by 0°, 90°, 180°, or 270°
/// 2. Model predicts which rotation was applied (4-class classification)
/// 3. Model learns spatial relationships and features without needing class labels
/// </para>
/// <para><b>For Beginners:</b> This teaches the model to understand image structure without labels.
///
/// Imagine showing someone 100 photos, each rotated randomly:
/// - They learn to recognize: which way is "up", spatial relationships, object orientations
/// - They don't need to know: what the objects are (no labels needed)
///
/// After this training, when you show them 5 labeled cat photos:
/// - They already understand image structure
/// - They just need to learn: "cats look like THIS"
/// - Much faster than learning everything from scratch!
///
/// <b>How it works:</b>
/// 1. Take each unlabeled image
/// 2. Create 4 versions: rotated by 0°, 90°, 180°, 270°
/// 3. Label each version: 0, 1, 2, 3 (which rotation was applied)
/// 4. Train model to predict the rotation
///
/// <b>What the model learns:</b>
/// - Edge orientations
/// - Spatial relationships
/// - Object structure
/// - "Natural" vs "unnatural" orientations
///
/// These features are very useful for actual classification tasks!
/// </para>
/// <para><b>Note:</b> Only square images are supported; non-square inputs are rejected,
/// because a 90°/270° rotation of an H×W image cannot fit back into an [H, W] slot.
/// </para>
/// </remarks>
public class RotationPredictionLoss<T> : ISelfSupervisedLoss<T>
{
    private static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>();

    /// <inheritdoc/>
    public (TInput augmentedX, TOutput augmentedY) CreateTask<TInput, TOutput>(TInput input)
    {
        if (input is not Tensor<T> tensorInput)
        {
            throw new NotSupportedException(
                $"RotationPredictionLoss only supports Tensor<T> input, but received {typeof(TInput)}");
        }

        // Validate input shape (should be [N, H, W] or [N, H, W, C])
        if (tensorInput.Shape.Length < 3)
        {
            throw new ArgumentException(
                $"Input tensor must have at least 3 dimensions [N, H, W] or [N, H, W, C], " +
                $"but got shape [{string.Join(", ", tensorInput.Shape)}]");
        }

        int numImages = tensorInput.Shape[0];
        int height = tensorInput.Shape[1];
        int width = tensorInput.Shape[2];
        int channels = tensorInput.Shape.Length > 3 ? tensorInput.Shape[3] : 1;

        // 90°/270° rotations map source columns to destination rows, so a
        // non-square image produces destination indices outside the [H, W]
        // slot of the output tensor. Fail fast with a clear message instead
        // of writing out of bounds.
        if (height != width)
        {
            throw new ArgumentException(
                $"RotationPredictionLoss requires square images (height == width), " +
                $"but got height={height}, width={width}.");
        }

        // Create rotated versions (4 rotations per image)
        int totalRotatedImages = numImages * 4;

        // Allocate the output with the SAME rank as the input. Previously the
        // output was always 4D, but CopyImage writes with only three indices
        // for 3D (grayscale) sources — a rank mismatch between allocation and
        // write. Matching ranks keeps the indexing consistent and preserves
        // the caller's shape convention ([N*4, H, W] in, [N*4, H, W] out).
        var augmentedX = tensorInput.Shape.Length == 3
            ? new Tensor<T>(new[] { totalRotatedImages, height, width })
            : new Tensor<T>(new[] { totalRotatedImages, height, width, channels });
        var augmentedY = new Tensor<T>(new[] { totalRotatedImages, 4 }); // 4-class one-hot

        int outputIdx = 0;
        for (int imgIdx = 0; imgIdx < numImages; imgIdx++)
        {
            // Create 4 rotations (0°, 90°, 180°, 270°)
            for (int rotationClass = 0; rotationClass < 4; rotationClass++)
            {
                // Copy rotated image to output
                RotateAndCopy(tensorInput, augmentedX, imgIdx, outputIdx, rotationClass, height, width, channels);

                // Store rotation label (one-hot encoding)
                for (int classIdx = 0; classIdx < 4; classIdx++)
                {
                    augmentedY[outputIdx, classIdx] = (classIdx == rotationClass) ? NumOps.One : NumOps.Zero;
                }

                outputIdx++;
            }
        }

        return ((TInput)(object)augmentedX, (TOutput)(object)augmentedY);
    }

    /// <summary>
    /// Rotates an image and copies it to the destination tensor.
    /// </summary>
    /// <param name="source">Source tensor containing images.</param>
    /// <param name="dest">Destination tensor for rotated images.</param>
    /// <param name="srcIdx">Index of source image.</param>
    /// <param name="destIdx">Index in destination tensor.</param>
    /// <param name="rotationClass">Rotation class (0=0°, 1=90°, 2=180°, 3=270°).</param>
    /// <param name="height">Image height.</param>
    /// <param name="width">Image width.</param>
    /// <param name="channels">Number of color channels.</param>
    /// <remarks>
    /// <b>Note:</b> Callers must guarantee square images (CreateTask validates this);
    /// for 90°/270° the transforms below map source columns to destination rows,
    /// which is only in-bounds when height == width.
    /// </remarks>
    private void RotateAndCopy(
        Tensor<T> source,
        Tensor<T> dest,
        int srcIdx,
        int destIdx,
        int rotationClass,
        int height,
        int width,
        int channels)
    {
        switch (rotationClass)
        {
            case 0:
                // No rotation (0°)
                CopyImage(source, dest, srcIdx, destIdx, height, width, channels,
                    (i, j) => (i, j));
                break;

            case 1:
                // Rotate 90° clockwise: (i, j) → (j, height-1-i)
                CopyImage(source, dest, srcIdx, destIdx, height, width, channels,
                    (i, j) => (j, height - 1 - i));
                break;

            case 2:
                // Rotate 180°: (i, j) → (height-1-i, width-1-j)
                CopyImage(source, dest, srcIdx, destIdx, height, width, channels,
                    (i, j) => (height - 1 - i, width - 1 - j));
                break;

            case 3:
                // Rotate 270° clockwise (90° counter-clockwise): (i, j) → (width-1-j, i)
                CopyImage(source, dest, srcIdx, destIdx, height, width, channels,
                    (i, j) => (width - 1 - j, i));
                break;

            default:
                throw new ArgumentException($"Invalid rotation class: {rotationClass}. Must be 0-3.");
        }
    }

    /// <summary>
    /// Copies an image with a coordinate transformation.
    /// </summary>
    /// <param name="source">Source tensor.</param>
    /// <param name="dest">Destination tensor (must have the same rank as <paramref name="source"/>).</param>
    /// <param name="srcIdx">Source image index.</param>
    /// <param name="destIdx">Destination image index.</param>
    /// <param name="height">Image height.</param>
    /// <param name="width">Image width.</param>
    /// <param name="channels">Number of channels.</param>
    /// <param name="transform">Coordinate transformation function (srcCoord → destCoord).</param>
    private void CopyImage(
        Tensor<T> source,
        Tensor<T> dest,
        int srcIdx,
        int destIdx,
        int height,
        int width,
        int channels,
        Func<int, int, (int, int)> transform)
    {
        for (int i = 0; i < height; i++)
        {
            for (int j = 0; j < width; j++)
            {
                var (destI, destJ) = transform(i, j);

                // Handle 3D tensors [N, H, W] (grayscale); dest is also 3D here
                // because CreateTask allocates it with the source's rank.
                if (source.Shape.Length == 3)
                {
                    dest[destIdx, destI, destJ] = source[srcIdx, i, j];
                }
                // Handle 4D tensors [N, H, W, C] (color images)
                else
                {
                    for (int c = 0; c < channels; c++)
                    {
                        dest[destIdx, destI, destJ, c] = source[srcIdx, i, j, c];
                    }
                }
            }
        }
    }
}

0 commit comments

Comments
 (0)