-
-
Notifications
You must be signed in to change notification settings - Fork 7
Fix issue 421 in AiDotNet repository #436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,327 @@ | ||||||||||||||||||||||
| using System.Numerics; | ||||||||||||||||||||||
| using AiDotNet.Interfaces; | ||||||||||||||||||||||
| using AiDotNet.Models; | ||||||||||||||||||||||
| using AiDotNet.Models.Options; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| namespace AiDotNet.AdversarialRobustness.Alignment; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||
| /// Implements Reinforcement Learning from Human Feedback (RLHF) for AI alignment. | ||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||
| /// <remarks> | ||||||||||||||||||||||
| /// <para> | ||||||||||||||||||||||
| /// RLHF trains models to align with human preferences by learning a reward model | ||||||||||||||||||||||
| /// from human feedback and using it to fine-tune the model via reinforcement learning. | ||||||||||||||||||||||
| /// </para> | ||||||||||||||||||||||
| /// <para><b>For Beginners:</b> RLHF is like having a human teacher grade the AI's responses | ||||||||||||||||||||||
| /// and using those grades to improve the AI. The AI learns what humans prefer and adjusts | ||||||||||||||||||||||
| /// its behavior accordingly. This is how models like ChatGPT learn to be helpful and follow | ||||||||||||||||||||||
| /// instructions.</para> | ||||||||||||||||||||||
| /// <para> | ||||||||||||||||||||||
| /// Original approaches: "Learning to summarize from human feedback" (OpenAI, 2020), | ||||||||||||||||||||||
| /// "Training language models to follow instructions with human feedback" (InstructGPT, 2022) | ||||||||||||||||||||||
| /// </para> | ||||||||||||||||||||||
| /// </remarks> | ||||||||||||||||||||||
| /// <typeparam name="T">The numeric data type used for calculations.</typeparam> | ||||||||||||||||||||||
| public class RLHFAlignment<T> : IAlignmentMethod<T> | ||||||||||||||||||||||
| where T : struct, INumber<T> | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| private readonly AlignmentMethodOptions<T> options; | ||||||||||||||||||||||
| private Func<T[], T[], double>? rewardModel; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||
| /// Initializes a new instance of RLHF alignment. | ||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||
| /// <param name="options">The alignment configuration options.</param> | ||||||||||||||||||||||
| public RLHFAlignment(AlignmentMethodOptions<T> options) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| this.options = options; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public Func<T[], T[]> AlignModel(Func<T[], T[]> baseModel, AlignmentFeedbackData<T> feedbackData) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Step 1: Train a reward model from human preferences | ||||||||||||||||||||||
| rewardModel = TrainRewardModel(feedbackData); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Step 2: Fine-tune the policy model using the reward model | ||||||||||||||||||||||
| var alignedModel = FinetuneWithRL(baseModel, feedbackData, rewardModel); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return alignedModel; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public AlignmentMetrics<T> EvaluateAlignment(Func<T[], T[]> model, AlignmentEvaluationData<T> evaluationData) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var metrics = new AlignmentMetrics<T>(); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| int helpfulCount = 0; | ||||||||||||||||||||||
| int harmlessCount = 0; | ||||||||||||||||||||||
| int honestCount = 0; | ||||||||||||||||||||||
| double totalPreferenceMatch = 0.0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int i = 0; i < evaluationData.TestInputs.Length; i++) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var output = model(evaluationData.TestInputs[i]); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Evaluate helpfulness (simplified) | ||||||||||||||||||||||
| if (IsHelpful(output, evaluationData.ExpectedOutputs[i])) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| helpfulCount++; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Evaluate harmlessness (simplified) | ||||||||||||||||||||||
| if (IsHarmless(output)) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| harmlessCount++; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Evaluate honesty (simplified) | ||||||||||||||||||||||
| if (IsHonest(output, evaluationData.TestInputs[i])) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| honestCount++; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Preference matching | ||||||||||||||||||||||
| if (i < evaluationData.ReferenceScores.Length) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var predictedScore = rewardModel?.Invoke(evaluationData.TestInputs[i], output) ?? 0.5; | ||||||||||||||||||||||
| var referenceScore = evaluationData.ReferenceScores[i]; | ||||||||||||||||||||||
| totalPreferenceMatch += 1.0 - Math.Abs(predictedScore - referenceScore); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| int total = evaluationData.TestInputs.Length; | ||||||||||||||||||||||
| metrics.HelpfulnessScore = (double)helpfulCount / total; | ||||||||||||||||||||||
| metrics.HarmlessnessScore = (double)harmlessCount / total; | ||||||||||||||||||||||
| metrics.HonestyScore = (double)honestCount / total; | ||||||||||||||||||||||
| metrics.PreferenceMatchRate = totalPreferenceMatch / total; | ||||||||||||||||||||||
| metrics.OverallAlignmentScore = (metrics.HelpfulnessScore + metrics.HarmlessnessScore + metrics.HonestyScore) / 3.0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return metrics; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public Func<T[], T[]> ApplyConstitutionalPrinciples(Func<T[], T[]> model, string[] principles) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Wrap the model with constitutional AI principles | ||||||||||||||||||||||
| return (input) => | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Generate initial response | ||||||||||||||||||||||
| var initialResponse = model(input); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Critique and revise based on principles | ||||||||||||||||||||||
| for (int i = 0; i < options.CritiqueIterations; i++) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var critique = GenerateCritique(initialResponse, principles); | ||||||||||||||||||||||
| initialResponse = ReviseBasedOnCritique(model, input, initialResponse, critique); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return initialResponse; | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public RedTeamingResults<T> PerformRedTeaming(Func<T[], T[]> model, T[][] adversarialPrompts) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var results = new RedTeamingResults<T> | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| AdversarialPrompts = adversarialPrompts, | ||||||||||||||||||||||
| ModelResponses = new T[adversarialPrompts.Length][], | ||||||||||||||||||||||
| SuccessfulAttacks = new bool[adversarialPrompts.Length], | ||||||||||||||||||||||
| SeverityScores = new double[adversarialPrompts.Length], | ||||||||||||||||||||||
| VulnerabilityTypes = new string[adversarialPrompts.Length], | ||||||||||||||||||||||
| Vulnerabilities = new List<VulnerabilityReport>() | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| int successfulAttacks = 0; | ||||||||||||||||||||||
| double totalSeverity = 0.0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int i = 0; i < adversarialPrompts.Length; i++) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var response = model(adversarialPrompts[i]); | ||||||||||||||||||||||
| results.ModelResponses[i] = response; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Analyze response for misalignment | ||||||||||||||||||||||
| var (isVulnerable, severity, vulnType) = AnalyzeResponseForVulnerability(response); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| results.SuccessfulAttacks[i] = isVulnerable; | ||||||||||||||||||||||
| results.SeverityScores[i] = severity; | ||||||||||||||||||||||
| results.VulnerabilityTypes[i] = vulnType; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (isVulnerable) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| successfulAttacks++; | ||||||||||||||||||||||
| totalSeverity += severity; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| results.Vulnerabilities.Add(new VulnerabilityReport | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| Type = vulnType, | ||||||||||||||||||||||
| Severity = severity, | ||||||||||||||||||||||
| Description = $"Model showed misaligned behavior of type: {vulnType}", | ||||||||||||||||||||||
| ExamplePrompt = ConvertToString(adversarialPrompts[i]), | ||||||||||||||||||||||
| ProblematicResponse = ConvertToString(response), | ||||||||||||||||||||||
| Recommendations = new[] | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "Add safety filters", | ||||||||||||||||||||||
| "Improve RLHF training data", | ||||||||||||||||||||||
| "Strengthen constitutional principles" | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| results.SuccessRate = (double)successfulAttacks / adversarialPrompts.Length; | ||||||||||||||||||||||
| results.AverageSeverity = successfulAttacks > 0 ? totalSeverity / successfulAttacks : 0.0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return results; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public AlignmentMethodOptions<T> GetOptions() => options; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public void Reset() { } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public byte[] Serialize() | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var json = System.Text.Json.JsonSerializer.Serialize(options); | ||||||||||||||||||||||
| return System.Text.Encoding.UTF8.GetBytes(json); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public void Deserialize(byte[] data) { } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public void SaveModel(string filePath) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| File.WriteAllBytes(filePath, Serialize()); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// <inheritdoc/> | ||||||||||||||||||||||
| public void LoadModel(string filePath) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| Deserialize(File.ReadAllBytes(filePath)); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| private Func<T[], T[], double> TrainRewardModel(AlignmentFeedbackData<T> feedbackData) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Train a reward model from human preference comparisons | ||||||||||||||||||||||
| // This is a simplified placeholder - real implementation would use neural networks | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return (input, output) => | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Simple reward heuristic based on output characteristics | ||||||||||||||||||||||
| var outputSum = output.Sum(x => double.CreateChecked(x)); | ||||||||||||||||||||||
| var outputMean = outputSum / output.Length; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Higher reward for moderate values (heuristic for "reasonable" outputs) | ||||||||||||||||||||||
| var reward = 1.0 - Math.Abs(outputMean - 0.5); | ||||||||||||||||||||||
| return Math.Clamp(reward, 0.0, 1.0); | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| private Func<T[], T[]> FinetuneWithRL(Func<T[], T[]> baseModel, AlignmentFeedbackData<T> feedbackData, Func<T[], T[], double> rewardModel) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| // Simplified PPO-like fine-tuning | ||||||||||||||||||||||
| // Real implementation would integrate with a RL framework | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return (input) => | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| var output = baseModel(input); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Apply KL penalty to stay close to base model | ||||||||||||||||||||||
| var klPenalty = options.KLCoefficient; | ||||||||||||||||||||||
|
||||||||||||||||||||||
| var klPenalty = options.KLCoefficient; |
Copilot
AI
Nov 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The array is iterated twice separately for Max and Min. Consider iterating once to find both values simultaneously for better performance.
| var maxValue = output.Max(x => double.CreateChecked(x)); | |
| var minValue = output.Min(x => double.CreateChecked(x)); | |
| double maxValue = double.MinValue; | |
| double minValue = double.MaxValue; | |
| foreach (var x in output) | |
| { | |
| var val = double.CreateChecked(x); | |
| if (val > maxValue) maxValue = val; | |
| if (val < minValue) minValue = val; | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Local scope variable 'rewardModel' shadows RLHFAlignment`1.rewardModel.