Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
45 changes: 21 additions & 24 deletions Generator/CodeGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,37 +1,20 @@
using System.Text;
using System.Threading.Channels;
namespace System.Management.Generator;

public class CodeGenerator
public class CodeGenerator(ChannelReader<ClassDefinition> channel)
{
private static readonly HashSet<string> _excludedFolders = new(StringComparer.OrdinalIgnoreCase) { "bin", "obj" };

private static DirectoryInfo FindTypesDirectory()
{
var dir = new DirectoryInfo(Environment.CurrentDirectory);
while (dir != null)
{
var typesDir = new DirectoryInfo(Path.Combine(dir.FullName, "Types"));
if (typesDir.Exists)
return typesDir;
dir = dir.Parent;
}
throw new DirectoryNotFoundException("Could not find 'Types' directory in any parent directory.");
}
private readonly DirectoryInfo _targetDirectory = FindTypesDirectory();
private readonly Dictionary<string, ClassDefinition> _classDefinitions = new(StringComparer.OrdinalIgnoreCase);

private readonly Dictionary<string, ClassDefinition> _classDefinitions;
private readonly DirectoryInfo _targetDirectory;

public CodeGenerator(IEnumerable<ClassDefinition> classDefinitions)
{
_classDefinitions = classDefinitions.ToDictionary(t => t.ClassName, t => t, StringComparer.InvariantCultureIgnoreCase);
_targetDirectory = FindTypesDirectory();
}

public void GenerateCode()
public async Task GenerateCode()
{
var existingFiles = _targetDirectory.GetDirectories().Where(d => !_excludedFolders.Contains(d.Name)).SelectMany(d => d.GetFiles("*.g.cs", SearchOption.AllDirectories)).Select(fi => fi.FullName).ToHashSet(StringComparer.OrdinalIgnoreCase);
foreach (var typeDefinition in _classDefinitions.Values)
await foreach (var typeDefinition in channel.ReadAllAsync())
{
_classDefinitions[typeDefinition.ClassName] = typeDefinition;
Console.WriteLine($"Generating class for {typeDefinition.ClassName}.");
(var namespaceName, var className) = ParseClassName(typeDefinition.ClassName);

Expand Down Expand Up @@ -92,6 +75,19 @@ public void GenerateCode()
}
}

private static DirectoryInfo FindTypesDirectory()
{
var dir = new DirectoryInfo(Environment.CurrentDirectory);
while (dir != null)
{
var typesDir = new DirectoryInfo(Path.Combine(dir.FullName, "Types"));
if (typesDir.Exists)
return typesDir;
dir = dir.Parent;
}
throw new DirectoryNotFoundException("Could not find 'Types' directory in any parent directory.");
}

private IEnumerable<PropertyDefinition> GetInheritedPropertiesFrom(string? superClass)
=> superClass != null && _classDefinitions.TryGetValue(superClass, out var td) ? td.Properties.Concat(GetInheritedPropertiesFrom(td.SuperClass)) : [];

Expand All @@ -109,6 +105,7 @@ private string GenerateClassCode(string namespaceName, string className, string
* Any changes made to this file will be overwritten. *
* *
**************************************************************/
#nullable enable
""");
sb.AppendLine($"namespace System.Management.Types.{namespaceName};");
sb.AppendLine();
Expand Down
137 changes: 67 additions & 70 deletions Generator/DefinitionLoader.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.RegularExpressions;
using System.Threading.Channels;
namespace System.Management.Generator;

internal partial class DefinitionLoader(IEnumerable<string> classNames)
internal partial class DefinitionLoader(ChannelWriter<ClassDefinition> channel)
{

[GeneratedRegex(@"<code class=""lang-syntax"">([^<]*)</code>")]
Expand Down Expand Up @@ -33,68 +34,58 @@ internal partial class DefinitionLoader(IEnumerable<string> classNames)
private static partial Regex GetLinkRegex();
private static readonly Regex LinkRegex = GetLinkRegex();

private readonly Dictionary<string, ClassDefinition> _classDefinitions = classNames.ToDictionary(n => n, n => default(ClassDefinition), StringComparer.OrdinalIgnoreCase);
public IEnumerable<ClassDefinition> LoadedClassDefinitions => _classDefinitions.Values;
private readonly Dictionary<string, string> _typeMismatches = new() { ["Win32_LogicalElement"] = "CIM_LogicalElement" };
private readonly HashSet<string> _loadedClassDefinitions = new(StringComparer.OrdinalIgnoreCase);

private string? CheckClass(string? className)
public async Task Load(IEnumerable<string> classesToLoad)
{
foreach (var className in classesToLoad)
{
await TryLoadClass(className);
}

channel.Complete();
}

private async Task<string?> TryLoadClass(string? className)
{
if (className == null)
{
return null;
}

if ("Win32_LogicalElement".Equals(className))
if (_typeMismatches.TryGetValue(className, out var correctClassName))
{
return CheckClass("CIM_LogicalElement");
return await TryLoadClass(correctClassName);
}

if (!_classDefinitions.ContainsKey(className))
if (!_loadedClassDefinitions.Add(className))
{
if (className.IndexOf('_') == -1)
{
return CheckClass($"__{className}");
}

_classDefinitions.Add(className, default);
Console.WriteLine($"Met new class {className}.");
return className;
}
return className;
}

public async Task Load()
{
foreach (var className in GetUnloadedClassNames())
if (await LoadClassDefinition(className) is ClassDefinition classDefinition)
{
if (await LoadClassDefinition(className) is ClassDefinition classDefinition)
{
_classDefinitions[className] = classDefinition;
}
else
{
_classDefinitions.Remove(className);
}
Console.WriteLine($"Loaded info for {className}:");
await channel.WriteAsync(classDefinition);
return className;
}
}

private IEnumerable<string> GetUnloadedClassNames()
{
var classNames = _classDefinitions.Where(kvp => kvp.Value.ClassName == null).Select(kvp => kvp.Key).ToArray();
while (classNames.Length > 0)
else if (className.IndexOf('_') == -1)
{
foreach (var className in classNames)
{
yield return className;
}
classNames = _classDefinitions.Where(kvp => kvp.Value.ClassName == null).Select(kvp => kvp.Key).ToArray();
var prefixedClassName = $"__{className}";
_typeMismatches[className] = prefixedClassName;
return await TryLoadClass(prefixedClassName);
}

return null;
}

private async Task<ClassDefinition?> LoadClassDefinition(string className)
{
Console.WriteLine($"Getting info for {className}:");
Console.WriteLine($"Loading info for {className}:");
Uri? classUri = null;
Uri[] uris = className[0] == '_'
? [new Uri($"https://learn.microsoft.com/en-gb/windows/win32/wmisdk/{className.Replace('_', '-')}")]
? [new Uri($"https://learn.microsoft.com/en-us/windows/win32/wmisdk/{className.Replace('_', '-')}")]
: [new Uri($"https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/{className.Replace('_', '-')}"),
new Uri($"https://learn.microsoft.com/en-us/previous-versions/windows/desktop/secrcw32prov//{className.Replace('_', '-')}")];

Expand All @@ -111,15 +102,14 @@ private IEnumerable<string> GetUnloadedClassNames()
{
if (i + 1 == uris.Length)
{
ErrorReporter.Report($"Unable to find Microsoft Learn page to parse for {className}: {ex.Message}", ex);
ErrorReporter.Report($"Unable to find Microsoft Learn page to parse for {className}: {ex.Message}", ex, throwOrBreak: false);
break;
}
}
}

if (classUri == null || pageContents == null)
{
ErrorReporter.Report($"Failed to load {className}.");
return null;
}

Expand Down Expand Up @@ -151,14 +141,14 @@ private IEnumerable<string> GetUnloadedClassNames()
ErrorReporter.Report($"Encounterd different type {classDef[1]} when parsing data for {className}.");
}

superClass = CheckClass(classDef.Length > 2 ? classDef[^1] : null);
superClass = await TryLoadClass(classDef.Length > 2 ? classDef[^1] : null);

break;
}

var propertyBlock = pageContents.IndexOf("<h3 id=\"properties\"");
var endBlock = pageContents.IndexOf("<h3", propertyBlock + 1);
var properties = propertyBlock == -1 ? [] : ParseProperties(codeLines[(lineIndex + 2)..^1], pageContents[propertyBlock..endBlock]).ToList();
var properties = propertyBlock == -1 ? [] : await ParseProperties(codeLines[(lineIndex + 2)..^1], pageContents[propertyBlock..endBlock]);

var methodBlock = pageContents.IndexOf("<h3 id=\"methods\"");
endBlock = pageContents.IndexOf("<h3", methodBlock + 1);
Expand All @@ -177,31 +167,37 @@ private static async Task<string> GetPageContentsAsync(Uri url)
return await response.Content.ReadAsStringAsync();
}

private IEnumerable<PropertyDefinition> ParseProperties(string[] propertyLines, string propertiesBlock)
private async Task<List<PropertyDefinition>> ParseProperties(string[] propertyLines, string propertiesBlock)
{
foreach (var property in propertyLines.Select(ParsePropertyLine).OfType<PropertyDefinition>())
var result = new List<PropertyDefinition>(propertyLines.Length);
foreach (var propertyLine in propertyLines)
{
if (TryGetPropertyV1(property.Name, propertiesBlock, out var propertyBlock))
if (await ParsePropertyLine(propertyLine) is PropertyDefinition property)
{
if (PropertyIsInherited(propertyBlock))
if (TryGetPropertyV1(property.Name, propertiesBlock, out var propertyBlock))
{
continue;
if (PropertyIsInherited(propertyBlock))
{
continue;
}
result.Add(await UpdatePropertyV1(property, propertyBlock));
}
yield return UpdatePropertyV1(property, propertyBlock);
}
else if (TryGetPropertyV2(property.Name, propertiesBlock, out propertyBlock))
{
if (PropertyIsInherited(propertyBlock))
else if (TryGetPropertyV2(property.Name, propertiesBlock, out propertyBlock))
{
continue;
if (PropertyIsInherited(propertyBlock))
{
continue;
}
result.Add(await UpdatePropertyV2(property, propertyBlock));
}
else
{
ErrorReporter.Report($"No description found for property {property.Name}.", throwOrBreak: false);
}
yield return UpdatePropertyV2(property, propertyBlock);
}
else
{
ErrorReporter.Report($"No description found for property {property.Name}.", throwOrBreak: false);
}
}

return result;
}

private static bool TryGetPropertyV1(string propertyName, string propertiesBlock, [MaybeNullWhen(false)]out string propertyBlock)
Expand Down Expand Up @@ -235,11 +231,12 @@ private static bool TryGetPropertyV2(string propertyName, string propertiesBlock
private static bool PropertyIsInherited(string propertyBlock)
=> propertyBlock.Contains("This property is inherited from");

private PropertyDefinition UpdatePropertyV1(PropertyDefinition property, string propertyBlock)
private async Task<PropertyDefinition> UpdatePropertyV1(PropertyDefinition property, string propertyBlock)
{
foreach (var paragraph in ParagraphRegex.Matches(propertyBlock).Select(TrimHTML))
{
if (!ParseSubProperty(ref property, paragraph))
(var parsed, property) = await ParseSubProperty(property, paragraph);
if (!parsed)
{
property = property with { Description = paragraph };
break;
Expand All @@ -249,11 +246,11 @@ private PropertyDefinition UpdatePropertyV1(PropertyDefinition property, string
return property;
}

private PropertyDefinition UpdatePropertyV2(PropertyDefinition property, string propertyBlock)
private async Task<PropertyDefinition> UpdatePropertyV2(PropertyDefinition property, string propertyBlock)
{
foreach (var propertyDescription in PropertyRegex.Matches(propertyBlock).Select(TrimHTML))
{
ParseSubProperty(ref property, propertyDescription);
(_, property) = await ParseSubProperty(property, propertyDescription);
}

if (TrimHTML(DescriptionRegex.Match(propertyBlock)) is string description && description.Length > 0)
Expand All @@ -264,12 +261,12 @@ private PropertyDefinition UpdatePropertyV2(PropertyDefinition property, string
return property;
}

private bool ParseSubProperty(ref PropertyDefinition property, string paragraph)
private async Task<(bool, PropertyDefinition)> ParseSubProperty(PropertyDefinition property, string paragraph)
{
var colonIndex = paragraph.IndexOf(':');
if (colonIndex == -1)
{
return false;
return (false, property);
}

switch (paragraph[..colonIndex])
Expand All @@ -289,7 +286,7 @@ private bool ParseSubProperty(ref PropertyDefinition property, string paragraph)
}
else if (typeName.Contains('_'))
{
property = property with { Type = CimType.Reference, ReferencedClass = CheckClass(typeName) };
property = property with { Type = CimType.Reference, ReferencedClass = await TryLoadClass(typeName) };
}
else
{
Expand All @@ -314,7 +311,7 @@ private bool ParseSubProperty(ref PropertyDefinition property, string paragraph)
property = property with { Qualifiers = ParseQualifiers(paragraph[(colonIndex + 2)..]) };
break;
}
return true;
return (true, property);
}

private static List<QualifierDefinition> ParseQualifiers(string qualifiers)
Expand All @@ -334,7 +331,7 @@ private static string TrimHTML(string source)
.Replace("&nbsp;", " ")
.Trim();

private PropertyDefinition? ParsePropertyLine(string propertyLine)
private async Task<PropertyDefinition?> ParsePropertyLine(string propertyLine)
{
var parts = TrimHTML(propertyLine).Split(' ', options: StringSplitOptions.RemoveEmptyEntries);

Expand All @@ -347,7 +344,7 @@ private static string TrimHTML(string source)
if (!Enum.TryParse(parts[0], ignoreCase: true, out CimType type) && parts[0].Contains('_'))
{
type = CimType.Reference;
referenceType = CheckClass(parts[0]);
referenceType = await TryLoadClass(parts[0]);
}

var name = parts[^2][0] == '=' ? parts[^3] : parts[^1];
Expand Down
10 changes: 6 additions & 4 deletions Generator/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// See https://aka.ms/new-console-template for more information
using System.Management.Generator;
using System.Threading.Channels;

List<string> classes =
[
Expand Down Expand Up @@ -80,7 +81,8 @@
"Win32_MethodParameterClass", "Win32_WMIElementSetting", "Win32_WMISetting",
];

var loader = new DefinitionLoader(classes);
await loader.Load();
var generator = new CodeGenerator(loader.LoadedClassDefinitions);
generator.GenerateCode();
var channel = Channel.CreateUnbounded<ClassDefinition>();

var loader = new DefinitionLoader(channel.Writer);
var generator = new CodeGenerator(channel.Reader);
await Task.WhenAll(loader.Load(classes), generator.GenerateCode());
1 change: 1 addition & 0 deletions Types/Base/_ACE.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Any changes made to this file will be overwritten. *
* *
**************************************************************/
#nullable enable
namespace System.Management.Types.Base;

public partial record class _ACE(ManagementObject ManagementObject) : _SecurityRelatedClass(ManagementObject)
Expand Down
1 change: 1 addition & 0 deletions Types/Base/_AbsoluteTimerInstruction.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Any changes made to this file will be overwritten. *
* *
**************************************************************/
#nullable enable
namespace System.Management.Types.Base;

public partial record class _AbsoluteTimerInstruction(ManagementObject ManagementObject) : _TimerInstruction(ManagementObject)
Expand Down
1 change: 1 addition & 0 deletions Types/Base/_AggregateEvent.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Any changes made to this file will be overwritten. *
* *
**************************************************************/
#nullable enable
namespace System.Management.Types.Base;

public partial record class _AggregateEvent(ManagementObject ManagementObject) : _IndicationRelated(ManagementObject)
Expand Down
Loading