# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#
# Code generated by aaz-dev-tools
# --------------------------------------------------------------------------------------------

# pylint: skip-file
# flake8: noqa

from azure.cli.core.aaz import *


@register_command_group(
    # Fixed typo: the group was registered as "artifcat", which would surface
    # a misspelled command group to every CLI user.
    "workload-orchestration artifact",
)
class __CMDGroup(AAZCommandGroup):
    """workload-orchestration artifact helps to generate workload-orchestration artifacts
    """
    # Docstring corrected: the previous text was copied from the
    # "configuration" command group and described the wrong group.
    pass


__all__ = ["__CMDGroup"]
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#
# Code generated by aaz-dev-tools
# --------------------------------------------------------------------------------------------

# pylint: skip-file
# flake8: noqa

from azure.cli.core.aaz import *


def _make_config_name(level_name):
    """Derive the configuration resource name from a target/site name.

    The target name is truncated to 18 characters before the "Config"
    suffix is appended, keeping the derived resource name bounded.

    :param level_name: target or site name (any value convertible via str()).
    :return: "<name[:18]>Config"
    """
    name = str(level_name)
    if len(name) > 18:
        name = name[:18]
    return name + "Config"


@register_command(
    # Fixed typo: was registered under the misspelled "artifcat" group.
    "workload-orchestration artifact generate",
    is_preview=False,
)
class ShowConfig(AAZCommand):
    """Get the configurations available at the specified hierarchical entity.

    :example: Show a configuration
        az workload-orchestration artifact generate -g rg1 --target-name target1 --solution-template-name solutionTemplate1
    """

    _aaz_info = {
        "version": "2024-08-01-preview",
        "resources": [
            ["mgmt-plane", "/subscriptions/{}/resourcegroups/{}/providers/Microsoft.Edge/solutions/{}", "2024-08-01-preview"],
        ]
    }

    def _handler(self, command_args):
        super()._handler(command_args)
        self._execute_operations()
        return self._output()

    _args_schema = None

    @classmethod
    def _build_arguments_schema(cls, *args, **kwargs):
        if cls._args_schema is not None:
            return cls._args_schema
        cls._args_schema = super()._build_arguments_schema(*args, **kwargs)

        # define Arg Group ""
        _args_schema = cls._args_schema
        _args_schema.resource_group = AAZResourceGroupNameArg(
            required=True,
        )
        # Optional: when omitted, the operation reads the shared "common"
        # configuration (see SolutionsGet.url_parameters).
        _args_schema.solution_name = AAZStrArg(
            options=["--solution-template-name"],
            help="The name of the Solution, This is required only to get solution configurations",
            id_part="name",
            fmt=AAZStrArgFormat(
                pattern="^[a-zA-Z0-9-]{3,24}$",
            ),
        )
        _args_schema.level_name = AAZStrArg(
            options=["--target-name"],
            help="The Target or Site name at which values needs to be set",
            required=True,
            fmt=AAZStrArgFormat(
                pattern="^[a-zA-Z0-9-]{3,24}$",
            ),
        )
        return cls._args_schema

    def _execute_operations(self):
        self.pre_operations()
        # The configuration resource is named "<target[:18]>Config"; rewrite
        # the argument in place so url_parameters picks up the derived name.
        self.ctx.args.level_name = _make_config_name(self.ctx.args.level_name)
        self.SolutionsGet(ctx=self.ctx)()
        self.post_operations()

    @register_callback
    def pre_operations(self):
        pass

    @register_callback
    def post_operations(self):
        pass

    def _output(self, *args, **kwargs):
        result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
        # Print only the configuration values; returning None (implicitly)
        # suppresses the default CLI rendering of the whole resource.
        print(result["properties"]["values"])

    class SolutionsGet(AAZHttpOperation):
        """GET the DynamicConfigurations values for a config/solution pair."""

        CLIENT_TYPE = "MgmtClient"

        def __call__(self, *args, **kwargs):
            request = self.make_request()
            session = self.client.send_request(request=request, stream=False, **kwargs)
            status = session.http_response.status_code
            if status in [200]:
                return self.on_200(session)
            if status in [404]:
                # A missing configuration is not an error for this command:
                # surface an empty set of values instead of failing.
                fallback = {"properties": {"values": "{}"}}
                self.ctx.set_var(
                    "instance",
                    fallback,
                    schema_builder=self._build_schema_on_404
                )
                return
            return self.on_error(session.http_response)

        @property
        def url(self):
            return self.client.format_url(
                "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Edge/configurations/{configName}/DynamicConfigurations/{solutionName}/versions/version1",
                **self.url_parameters
            )

        @property
        def method(self):
            return "GET"

        @property
        def error_format(self):
            return "MgmtErrorFormat"

        @property
        def url_parameters(self):
            # Fall back to the shared "common" configuration when no solution
            # template name was supplied by the caller.
            sol_name = "common"
            if has_value(self.ctx.args.solution_name):
                sol_name = self.ctx.args.solution_name

            parameters = {
                **self.serialize_url_param(
                    "resourceGroupName", self.ctx.args.resource_group,
                    required=True,
                ),
                **self.serialize_url_param(
                    "solutionName", sol_name,
                    required=True,
                ),
                **self.serialize_url_param(
                    "configName", self.ctx.args.level_name,
                    required=True,
                ),
                **self.serialize_url_param(
                    "subscriptionId", self.ctx.subscription_id,
                    required=True,
                ),
            }
            return parameters

        @property
        def query_parameters(self):
            parameters = {
                # NOTE(review): _aaz_info advertises 2024-08-01-preview but the
                # request is sent with 2024-06-01-preview -- confirm which
                # api-version the Microsoft.Edge endpoint actually supports.
                **self.serialize_query_param(
                    "api-version", "2024-06-01-preview",
                    required=True,
                ),
            }
            return parameters

        @property
        def header_parameters(self):
            parameters = {
                **self.serialize_header_param(
                    "Accept", "application/json",
                ),
            }
            return parameters

        def on_200(self, session):
            data = self.deserialize_http_content(session)
            self.ctx.set_var(
                "instance",
                data,
                schema_builder=self._build_schema_on_200
            )

        _schema_on_200 = None

        @classmethod
        def _build_schema_on_404(cls):
            # NOTE(review): this intentionally reuses the _schema_on_200 slot,
            # so a 404 processed first would make _build_schema_on_200 return
            # this minimal schema. Safe for a single request lifecycle, but
            # worth confirming if the operation class is ever reused.
            cls._schema_on_200 = AAZObjectType()
            _schema_on_200 = cls._schema_on_200
            _schema_on_200.properties = AAZFreeFormDictType()
            return cls._schema_on_200

        @classmethod
        def _build_schema_on_200(cls):
            if cls._schema_on_200 is not None:
                return cls._schema_on_200

            cls._schema_on_200 = AAZObjectType()

            _schema_on_200 = cls._schema_on_200
            _schema_on_200.id = AAZStrType(
                flags={"read_only": True},
            )
            _schema_on_200.location = AAZStrType(
                flags={"required": True},
            )
            _schema_on_200.name = AAZStrType(
                flags={"read_only": True},
            )
            _schema_on_200.properties = AAZFreeFormDictType()
            _schema_on_200.system_data = AAZObjectType(
                serialized_name="systemData",
                flags={"read_only": True},
            )
            _schema_on_200.tags = AAZDictType()
            _schema_on_200.type = AAZStrType(
                flags={"read_only": True},
            )

            system_data = cls._schema_on_200.system_data
            system_data.created_at = AAZStrType(
                serialized_name="createdAt",
            )
            system_data.created_by = AAZStrType(
                serialized_name="createdBy",
            )
            system_data.created_by_type = AAZStrType(
                serialized_name="createdByType",
            )
            system_data.last_modified_at = AAZStrType(
                serialized_name="lastModifiedAt",
            )
            system_data.last_modified_by = AAZStrType(
                serialized_name="lastModifiedBy",
            )
            system_data.last_modified_by_type = AAZStrType(
                serialized_name="lastModifiedByType",
            )

            tags = cls._schema_on_200.tags
            tags.Element = AAZStrType()

            return cls._schema_on_200


class _ShowHelper:
    """Helper class for Show"""


__all__ = ["ShowConfig"]
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#
# Code generated by aaz-dev-tools
# --------------------------------------------------------------------------------------------

# pylint: skip-file
# flake8: noqa

import os
import sys
import asyncio
from azure.cli.core.aaz import *

# Make the bundled wo_gen.py (under wo_artifact/src) importable.
SCRIPT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'wo_artifact', 'src')
sys.path.append(SCRIPT_DIR)

from wo_gen import main as wo_gen_main


@register_command(
    # Fixed typo: was registered under the misspelled "artifcat" group.
    "workload-orchestration artifact schema-generate",
    is_preview=False,
)
class SimpleGenerate(AAZCommand):
    """Generate artifacts using the specified chart and AI configuration.

    :example: Generate artifacts
        az workload-orchestration artifact schema-generate --chart-path /path/to/chart --schema-name schema --schema-version 1.0.0 --ai-endpoint url --ai-key key --ai-model model --output-dir ./output --prompt prompt.txt
    """
    # Docstring example corrected: it previously invoked "simple-generate",
    # which does not match the registered command name ("schema-generate").

    _aaz_info = {
        "version": "2024-08-01-preview"
    }

    def _handler(self, command_args):
        super()._handler(command_args)
        self._execute_operations()
        return self._output()

    _args_schema = None

    @classmethod
    def _build_arguments_schema(cls, *args, **kwargs):
        if cls._args_schema is not None:
            return cls._args_schema
        cls._args_schema = super()._build_arguments_schema(*args, **kwargs)

        _args_schema = cls._args_schema
        _args_schema.chart_path = AAZStrArg(
            options=["--chart-path"],
            help="Path to the chart directory",
            required=True,
        )
        _args_schema.schema_name = AAZStrArg(
            options=["--schema-name"],
            help="Schema name",
            required=True,
        )
        _args_schema.schema_version = AAZStrArg(
            options=["--schema-version"],
            help="Schema version",
            required=True,
        )
        _args_schema.ai_endpoint = AAZStrArg(
            options=["--ai-endpoint"],
            help="AI endpoint URL",
            required=True,
        )
        _args_schema.ai_key = AAZStrArg(
            options=["--ai-key"],
            help="AI authentication key",
            required=True,
        )
        _args_schema.ai_model = AAZStrArg(
            options=["--ai-model"],
            help="AI model name",
            required=True,
        )
        _args_schema.output_dir = AAZStrArg(
            options=["--output-dir"],
            help="Output directory path",
            required=True,
        )
        _args_schema.prompt_file = AAZStrArg(
            options=["--prompt"],
            help="Path to prompt file",
            required=True,
        )

        return cls._args_schema

    def _execute_operations(self):
        self.pre_operations()
        # Initialized before the try block so the except handler can always
        # reference it (previously a failure while building the dict would
        # have raised NameError inside the handler).
        args = {}
        # wo_gen_main() parses sys.argv; save it so the process-wide state is
        # restored even if generation fails.
        saved_argv = sys.argv
        try:
            # Extract arguments as strings
            args = {
                "chart_path": str(self.ctx.args.chart_path),
                "schema_name": str(self.ctx.args.schema_name),
                "schema_version": str(self.ctx.args.schema_version),
                "ai_endpoint": str(self.ctx.args.ai_endpoint),
                "ai_key": str(self.ctx.args.ai_key),
                "ai_model": str(self.ctx.args.ai_model),
                "output_dir": str(self.ctx.args.output_dir),
                "prompt": str(self.ctx.args.prompt_file)
            }

            # Set up sys.argv for wo_gen.py
            sys.argv = [
                'wo_gen.py',
                args["chart_path"],
                '--schema-name', args["schema_name"],
                '--schema-version', args["schema_version"],
                '--ai-endpoint', args["ai_endpoint"],
                '--ai-key', args["ai_key"],
                '--ai-model', args["ai_model"],
                '--output-dir', args["output_dir"],
                '--prompt', args["prompt"]
            ]

            # Run wo_gen.py main function
            asyncio.run(wo_gen_main())

            # Check if output files were generated
            schema_file = os.path.join(args["output_dir"], f"{args['schema_name']}-schema.yaml")
            template_file = os.path.join(args["output_dir"], f"{args['schema_name']}-template.yaml")

            if os.path.exists(schema_file) and os.path.exists(template_file):
                result = {
                    "properties": {
                        "status": "success",
                        "message": "Generation completed successfully",
                        "files": {
                            "schema": schema_file,
                            "template": template_file
                        }
                    }
                }
            else:
                raise Exception("Expected output files were not generated")

        except Exception as e:
            result = {
                "properties": {
                    "status": "error",
                    "message": str(e),
                    "error": str(e),
                    "args": args
                }
            }
        finally:
            sys.argv = saved_argv

        self.ctx.set_var("instance", result, schema_builder=self._build_schema)
        self.post_operations()

    _schema = None

    @classmethod
    def _build_schema(cls):
        if cls._schema is not None:
            return cls._schema

        cls._schema = AAZObjectType()
        _schema = cls._schema
        _schema.properties = AAZObjectType()
        properties = _schema.properties
        properties.status = AAZStrType()
        properties.message = AAZStrType()
        properties.error = AAZStrType(nullable=True)
        properties.files = AAZFreeFormDictType(nullable=True)
        # Declared so the error-path result, which includes the echoed CLI
        # arguments under "args", fits the schema.
        properties.args = AAZFreeFormDictType(nullable=True)
        return cls._schema

    @register_callback
    def pre_operations(self):
        pass

    @register_callback
    def post_operations(self):
        pass

    def _output(self, *args, **kwargs):
        """Render a human-readable summary and return the result properties."""
        result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
        if result["properties"].get("error"):
            print(f"Error: {result['properties']['error']}")
        else:
            print(f"Success: {result['properties']['message']}")
            if result["properties"].get("files"):
                print("\nGenerated files:")
                for file_type, file_path in result["properties"]["files"].items():
                    print(f"{file_type}: {file_path}")
        return result["properties"]


__all__ = ["SimpleGenerate"]
+ +## Features +- AI-powered parameter analysis +- Automatic parameter categorization +- Hierarchy-based configuration management +- Custom prompt support +- Nested parameter handling +- Smart validation rules + +## Prerequisites +- Python 3.8+ +- Azure OpenAI API access +- Azure CLI (for hierarchy management) +- Helm charts to analyze + +## Installation +```bash +pip install -r requirements.txt +``` + +## Usage +```bash +python src/wo_gen.py \ + --schema-name \ + --schema-version \ + --ai-endpoint \ + --ai-key \ + --ai-model \ + [--output-dir ./output] \ + [--prompt custom/prompt.txt] \ + [--verbose] +``` + +## Project Structure +``` +WOArtiGen/ +├── src/ +│ ├── ai_analyzer/ # AI analysis components +│ ├── helm_parser/ # Helm chart parsing +│ ├── schema_generator/ # Schema generation +│ ├── template_generator/ # Template generation +│ └── utils/ # Utility modules +├── prompts/ +│ ├── default/ # Default prompt templates +│ └── custom/ # Custom prompt templates +├── config/ # Configuration files +└── output/ # Generated artifacts +``` + +## Configuration + +### Hierarchy Levels +- Managed through config/hierarchy_levels.json +- Default: ['factory', 'line'] +- Auto-updates from Azure Edge contexts + +### Custom Prompts +- Place in prompts/custom/ +- Reference using --prompt argument +- Must specify guidelines for: + * Parameter configurability + * Required vs optional parameters + * Management responsibility (IT/OT) + * Hierarchy level assignment +- See prompts/custom/example_prompt.txt + +## Parameter Analysis + +### AI Response Format +For each parameter, the AI analyzes: +- configurable: Whether parameter can be modified +- required: Whether parameter must appear in template +- managed_by: Who can modify this parameter (IT/OT) +- edit_level: At which hierarchy level it can be modified + +### Management Levels + +#### IT (Information Technology) +- Security configurations +- Infrastructure settings +- Network parameters +- Compliance controls + +#### OT 
(Operational Technology) +- Production settings +- Performance tuning +- Operational thresholds +- Local customizations + +### Required vs Optional Parameters +Parameters marked as required will appear in the solution template. + +## Output Files + +### Schema (name-schema.yaml) +```yaml +name: schema-name +version: schema-version +rules: + configs: + parameter.name: + type: string|integer|boolean|array + required: true/false + editableAt: [hierarchy-level] + editableBy: [IT/OT] +``` + +### Template (name-template.yaml) +```yaml +schema: + name: schema-name + version: schema-version +configs: + parameter.name: ${$val(parameter.name)} +``` + +## Testing + +The project includes comprehensive unit tests for all components. Tests are located alongside their respective modules with the `test_` prefix. + +### Running Tests + +Run individual test files: +```bash +# Run specific component tests +python src/helm_parser/test_parser.py # Test chart parsing +python src/ai_analyzer/test_analyzer.py # Test AI analysis +python src/ai_analyzer/test_client.py # Test AI client +python src/schema_generator/test_generator.py # Test schema generation +python src/template_generator/test_generator.py # Test template generation +``` +### Integration Tests +Integration tests are included in `src/test_wo_gen.py` and cover: + +#### End-to-End Workflow +```bash +python src/test_wo_gen.py +``` + + +### links +### https://microsoftapc-my.sharepoint.com/personal/kup_microsoft_com/_layouts/15/stream.aspx?id=%2Fpersonal%2Fkup%5Fmicrosoft%5Fcom%2FDocuments%2FRecordings%2FRegular%20Sync%2D20250630%5F120340%2DMeeting%20Recording%2Emp4&referrer=StreamWebApp%2EWeb&referrerScenario=AddressBarCopied%2Eview%2Ede10ca09%2D93c4%2D4174%2Da75b%2De4d84b7071e6 + +##### 
https://microsoftapc-my.sharepoint.com/personal/kup_microsoft_com/_layouts/15/stream.aspx?id=%2Fpersonal%2Fkup%5Fmicrosoft%5Fcom%2FDocuments%2FRecordings%2FIntern%20Project%20Presentation%20%2D%20Kawalijeet%20%20Generate%20WO%20Artifacts%20using%20AI%20%5BIn%2Dperson%5D%2D20250630%5F100312%2DMeeting%20Recording%2Emp4&referrer=StreamWebApp%2EWeb&referrerScenario=AddressBarCopied%2Eview%2Ec1704787%2D5cc0%2D4ab7%2D919a%2D108729bb9665 + + +#### https://microsoftapc-my.sharepoint.com/:p:/g/personal/t-kawsingh_microsoft_com/EcSJzE3h2NNPofXMEVSXJfcBC8OIdy3V8usnDDCgjxBotA?wdOrigin=TEAMS-MAGLEV.p2p_ns.rwc&wdExp=TEAMS-TREATMENT&wdhostclicktime=1752024095271&web=1 \ No newline at end of file diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/Chart.yaml b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/Chart.yaml new file mode 100644 index 00000000000..e231a8271ba --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/Chart.yaml @@ -0,0 +1,5 @@ +apiVersion: v2 +name: sample-chart +description: A sample Helm chart for testing wo_gen.py +version: 0.1.0 +appVersion: "1.0.0" diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/_helpers.tpl b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/_helpers.tpl new file mode 100644 index 00000000000..f4e4b3c107b --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/_helpers.tpl @@ -0,0 +1,49 @@ +{{/* +Expand the name of the chart. 
+*/}} +{{- define "sample-chart.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +*/}} +{{- define "sample-chart.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "sample-chart.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "sample-chart.labels" -}} +helm.sh/chart: {{ include "sample-chart.chart" . }} +{{ include "sample-chart.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "sample-chart.selectorLabels" -}} +app.kubernetes.io/name: {{ include "sample-chart.name" . 
}} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/deployment.yaml b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/deployment.yaml new file mode 100644 index 00000000000..1e014544284 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/templates/deployment.yaml @@ -0,0 +1,35 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "sample-chart.fullname" . }} + labels: + {{- include "sample-chart.labels" . | nindent 4 }} +spec: + replicas: {{ .Values.replicaCount }} + selector: + matchLabels: + {{- include "sample-chart.selectorLabels" . | nindent 6 }} + template: + metadata: + labels: + {{- include "sample-chart.selectorLabels" . 
| nindent 8 }} + spec: + containers: + - name: {{ .Chart.Name }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + ports: + - name: http + containerPort: 80 + protocol: TCP + resources: + {{- toYaml .Values.resources | nindent 12 }} + env: + - name: CONFIG_NAME + value: {{ .Values.config.name }} + - name: CONFIG_VERSION + value: {{ .Values.config.version }} + - name: FEATURE1_ENABLED + value: "{{ .Values.config.settings.feature1 }}" + - name: FEATURE2_ENABLED + value: "{{ .Values.config.settings.feature2 }}" diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/values.yaml b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/values.yaml new file mode 100644 index 00000000000..63cfd1911e9 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/charts/sample-chart/values.yaml @@ -0,0 +1,36 @@ +# Default values for sample-chart +replicaCount: 1 + +image: + repository: nginx + pullPolicy: IfNotPresent + tag: "" + +nameOverride: "" +fullnameOverride: "" + +service: + type: ClusterIP + port: 80 + +resources: + limits: + cpu: 100m + memory: 128Mi + requests: + cpu: 100m + memory: 128Mi + +# Custom configuration section for wo_gen.py testing +config: + name: "default-config" + version: "1.0.0" + settings: + enabled: true + feature1: true + feature2: false + endpoints: + - name: "primary" + url: "http://example.com" + - name: "secondary" + url: "http://backup.example.com" diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/config/hierarchy_levels.json 
b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/config/hierarchy_levels.json new file mode 100644 index 00000000000..eb1ea3af4a5 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/config/hierarchy_levels.json @@ -0,0 +1,21 @@ +{ + "levels": [ + { + "description": "Country", + "name": "country" + }, + { + "description": "Region", + "name": "region" + }, + { + "description": "Factory", + "name": "factory" + }, + { + "description": "Line", + "name": "line" + } + ], + "last_updated": "2025-07-04T06:36:41.419450" +} \ No newline at end of file diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/custom/prompt.txt b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/custom/prompt.txt new file mode 100644 index 00000000000..5910645b5ef --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/custom/prompt.txt @@ -0,0 +1,16 @@ +You are a Helm chart analyzer. Given a Helm chart, you will: +1. Review the Chart.yaml, values.yaml, and template files +2. Analyze the configuration structure and dependencies +3. Identify potential issues or improvements +4. Consider security best practices +5. Validate the consistency between values and their usage in templates + +Please provide your analysis in a clear, structured format with specific examples and recommendations where applicable. 
+ +Base your analysis on: +- Chart metadata and version information +- Default values and their implications +- Template structure and usage patterns +- Configuration flexibility and extensibility +- Security considerations +- Best practices compliance diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/custom/test_prompt.txt b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/custom/test_prompt.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/default/prompt.txt b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/default/prompt.txt new file mode 100644 index 00000000000..ffe8d9cd52a --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/prompts/default/prompt.txt @@ -0,0 +1,17 @@ +For each parameter, consider: + +Configurability: Should users be able to modify this in production? Make it configurable if it needs runtime changes or environment-specific values. Keep it fixed for security settings, system IDs, or compliance requirements. + +Management: Is this an IT or OT concern? +- IT handles security, infrastructure, and compliance +- OT manages daily operations and production settings + +Hierarchy: Where should this be configured? Higher levels for broad impact settings and security. Lower levels for local operations and tuning. + +Required Status: Does this need a value in the template? 
Mark as required if it: +- Has security/behavioral impact +- Lacks a default value +- Is critical for operation +- Has compliance requirements + +Think about these in the context of a production environment where both security and operational flexibility are important. diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/requirements.txt b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/requirements.txt new file mode 100644 index 00000000000..64e7fc492dd --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/requirements.txt @@ -0,0 +1,7 @@ +PyYAML>=6.0 +click>=8.0 +openai>=1.0.0 +tiktoken>=0.5.0 +aiohttp>=3.8.0 +asyncio>=3.4.3 +python-dotenv>=1.0.0 diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/__init__.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/__init__.py new file mode 100644 index 00000000000..d74e273f1de --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/__init__.py @@ -0,0 +1 @@ +"""WO Artifact Generator package""" diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/ai_analyzer/__init__.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/ai_analyzer/__init__.py new file mode 100644 index 00000000000..b5a2ea9eebf --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/ai_analyzer/__init__.py @@ -0,0 +1,8 @@ +""" +AI-based parameter analyzer module +""" + +from .analyzer import AIParameterAnalyzer 
"""
AI-based parameter analyzer implementation
"""

from dataclasses import asdict
from typing import Dict, List, Any, Optional
import json
from utils.logger import LoggerMixin
from helm_parser.parser import ChartData, ChartParameter
from .client import AzureOpenAIClient

class AIParameterAnalyzer(LoggerMixin):
    """Analyzes Helm chart parameters using Azure OpenAI.

    Builds a context dictionary describing the chart, formats each
    parameter for the model, delegates the actual API call to
    AzureOpenAIClient, and filters the results down to the parameters
    the model marked as configurable.
    """

    # Fallback system prompt used when the caller passes no custom_prompt.
    # It must contain the '{chart_context}' placeholder, which
    # analyze_parameters substitutes with the JSON-serialized context.
    DEFAULT_SYSTEM_PROMPT = (
        "You are a specialized Kubernetes Configuration Analyzer.\n\n"
        "Chart context:\n{chart_context}\n\n"
        "Classify every parameter listed below according to the "
        "response template that follows."
    )

    def __init__(self, endpoint: str, api_key: str, deployment: str,
                 custom_prompt: Optional[str] = None,
                 hierarchy_levels: Optional[List[str]] = None):
        """
        Initialize the analyzer.

        Args:
            endpoint: Azure OpenAI endpoint URL
            api_key: Azure OpenAI API key
            deployment: Model deployment name
            custom_prompt: Optional custom system prompt; should contain the
                '{chart_context}' placeholder. When None, a built-in default
                prompt is used at analysis time.
            hierarchy_levels: Optional list of hierarchy levels
                (defaults to ['factory', 'line'])
        """
        super().__init__()
        self.ai_client = AzureOpenAIClient(
            endpoint=endpoint,
            api_key=api_key,
            deployment=deployment,
            hierarchy_levels=hierarchy_levels
        )
        # Stored as given (possibly None) so callers can still inspect it;
        # the None case is resolved lazily in analyze_parameters.
        self.system_prompt = custom_prompt
        self.hierarchy_levels = hierarchy_levels or ['factory', 'line']

    def _build_chart_context(self, chart_data: ChartData) -> Dict[str, Any]:
        """
        Build context information about the chart.

        Args:
            chart_data: Parsed chart data

        Returns:
            Dictionary containing chart context
        """
        # application_type / deployment_type / target_environment are fixed
        # hints for the model, not values derived from the chart itself.
        return {
            "name": chart_data.name,
            "version": chart_data.version,
            "description": chart_data.description,
            "dependencies": [dep.get('name') for dep in chart_data.dependencies],
            "application_type": "web_server",
            "deployment_type": "kubernetes",
            "target_environment": "production",
            "hierarchy_levels": self.hierarchy_levels
        }

    def _format_parameter(self, param: ChartParameter) -> Dict[str, Any]:
        """
        Format parameter data for AI analysis.

        Args:
            param: Chart parameter

        Returns:
            Dictionary containing formatted parameter data
        """
        return {
            "name": param.name,
            "type": param.type,
            "default_value": param.default_value,
            "description": param.description,
            # Dotted path for nested values; falls back to the flat name.
            "path": ".".join(param.nested_path) if param.nested_path else param.name
        }

    async def analyze_parameters(self, chart_data: ChartData) -> Dict[str, Dict[str, Any]]:
        """
        Analyze chart parameters using Azure OpenAI.

        Args:
            chart_data: Parsed chart data

        Returns:
            Dictionary mapping parameter names to their analysis results,
            restricted to parameters the model marked as configurable.

        Raises:
            Exception: re-raised from the underlying AI client on failure.
        """
        # Format all parameters for analysis
        formatted_params = [
            self._format_parameter(param)
            for param in chart_data.parameters.values()
        ]

        self.logger.info(f"Analyzing {len(formatted_params)} parameters")

        # Build context for AI
        chart_context = self._build_chart_context(chart_data)

        try:
            # BUG FIX: the original called self.system_prompt.replace()
            # unconditionally, which raised AttributeError whenever the
            # analyzer was constructed without a custom_prompt (a case
            # __init__ explicitly permits). Fall back to the built-in
            # default template in that case.
            template = self.system_prompt or self.DEFAULT_SYSTEM_PROMPT
            prompt = template.replace('{chart_context}', json.dumps(chart_context, indent=2))

            # Get AI analysis
            analysis_results = await self.ai_client.analyze_parameters(
                formatted_params,
                prompt
            )

            # Keep only parameters the model flagged as configurable.
            validated_results = {}
            for param_name, result in analysis_results.items():
                if result.get('configurable', False):
                    validated_results[param_name] = result
                else:
                    self.logger.debug(f"Filtered out {param_name}: non-essential or invalid")

            self.logger.info(f"AI filtering: {len(analysis_results)} -> {len(validated_results)} parameters")
            return validated_results

        except Exception as e:
            self.logger.error(f"Parameter analysis failed: {str(e)}")
            raise

    def get_analysis_stats(self, results: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate statistics about the analysis results.

        Args:
            results: Analysis results (each value must carry the
                'configurable', 'required', 'managed_by' and 'edit_level'
                keys produced by analyze_parameters)

        Returns:
            Dictionary containing analysis statistics
        """
        stats = {
            'total_parameters': len(results),
            'configurable': 0,
            'required': 0,
            'it_managed': 0,
            'ot_managed': 0,
            'hierarchy_levels': {level: 0 for level in self.hierarchy_levels},
        }

        for result in results.values():
            if result['configurable']:
                stats['configurable'] += 1
            if result['required']:
                stats['required'] += 1
            if result['managed_by'] == 'IT':
                stats['it_managed'] += 1
            if result['managed_by'] == 'OT':
                stats['ot_managed'] += 1
            stats['hierarchy_levels'][result['edit_level']] += 1

        return stats
"""
Azure OpenAI API client implementation
"""

import openai
from openai import AsyncAzureOpenAI
from typing import Dict, List, Any, Optional
import tiktoken
import asyncio
import json
import re
from utils.logger import LoggerMixin

class AzureOpenAIClient(LoggerMixin):
    """Client for interacting with Azure OpenAI API.

    Batches parameters to stay inside the model's context window, calls
    the chat-completions endpoint in JSON mode, retries with exponential
    backoff, and validates each returned parameter analysis.
    """

    GPT4_MAX_TOKENS = 8192   # GPT-4 context window
    GPT35_MAX_TOKENS = 4096  # GPT-3.5 context window
    RESPONSE_TOKENS = 1000   # Reserve tokens for response
    BATCH_SIZE = 6           # Process 6 parameters per batch

    def __init__(self, endpoint: str, api_key: str, deployment: str,
                 hierarchy_levels: Optional[List[str]] = None):
        """
        Initialize the Azure OpenAI client.

        Args:
            endpoint: Azure OpenAI endpoint URL
            api_key: Azure OpenAI API key
            deployment: Model deployment name
            hierarchy_levels: Optional list of hierarchy levels
        """
        super().__init__()
        # BUG FIX: api_version was "2023-05-15", which predates JSON-mode
        # support; requests that pass response_format={"type": "json_object"}
        # (as analyze_parameters does) are rejected by that API version.
        # 2024-02-01 is the first GA data-plane version with JSON mode.
        self.client = AsyncAzureOpenAI(
            api_key=api_key,
            api_version="2024-02-01",
            azure_endpoint=endpoint
        )
        self.deployment = deployment
        self.hierarchy_levels = hierarchy_levels or ['factory', 'line']

        # Set model-specific configurations based on the deployment name.
        self.is_gpt4 = "gpt-4" in deployment.lower() or "gpt4" in deployment.lower()
        self.max_tokens = self.GPT4_MAX_TOKENS if self.is_gpt4 else self.GPT35_MAX_TOKENS
        self.available_tokens = self.max_tokens - self.RESPONSE_TOKENS
        self.encoding = tiktoken.encoding_for_model("gpt-4" if self.is_gpt4 else "gpt-3.5-turbo")

    def _count_tokens(self, text: str) -> int:
        """Count tokens in a text string using the model's tokenizer."""
        tokens = self.encoding.encode(text)
        return len(tokens)

    def _create_batches(self, items: List[Dict[str, Any]],
                        system_prompt: str) -> List[List[Dict[str, Any]]]:
        """Create batches of fixed size while respecting token limits.

        A new batch is started when either BATCH_SIZE is reached or adding
        the next item would exceed the tokens available for the prompt.
        """
        batches = []
        current_batch = []
        current_tokens = self._count_tokens(system_prompt)

        for item in items:
            item_tokens = self._count_tokens(json.dumps(item))
            if len(current_batch) >= self.BATCH_SIZE or current_tokens + item_tokens > self.available_tokens:
                if current_batch:
                    batches.append(current_batch)
                current_batch = [item]
                current_tokens = self._count_tokens(system_prompt) + item_tokens
            else:
                current_batch.append(item)
                current_tokens += item_tokens

        if current_batch:
            batches.append(current_batch)

        return batches

    async def analyze_parameters(self, parameters: List[Dict[str, Any]],
                                 system_prompt: str,
                                 max_retries: int = 3) -> Dict[str, Dict[str, Any]]:
        """Analyze parameters using Azure OpenAI.

        Args:
            parameters: Formatted parameter dictionaries to analyze
            system_prompt: Prompt describing the analysis task
            max_retries: Attempts per batch before giving up on it

        Returns:
            Mapping of parameter name to its validated analysis result.
            Batches that fail all retries are omitted (best-effort).
        """
        results = {}
        batches = self._create_batches(parameters, system_prompt)

        # Create response template with current hierarchy levels
        response_template = self._create_response_template()

        for batch_idx, batch in enumerate(batches):
            self.logger.info(f"Processing batch {batch_idx + 1} of {len(batches)}")

            # Format batch parameters for prompt
            param_text = json.dumps(batch, indent=2)
            prompt = (
                f"{system_prompt}\n\n"
                f"{response_template}\n\n"
                f"Parameters to analyze (use exact names):\n{param_text}"
            )

            for attempt in range(max_retries):
                try:
                    messages = [
                        {"role": "system", "content": "You are a specialized Kubernetes Configuration Analyzer. You must use exact parameter names and follow the format strictly."},
                        {"role": "user", "content": prompt}
                    ]

                    response = await self.client.chat.completions.create(
                        model=self.deployment,
                        messages=messages,
                        temperature=0.0,
                        max_tokens=self.RESPONSE_TOKENS,
                        n=1,
                        response_format={"type": "json_object"}
                    )

                    # Get response content
                    result_text = response.choices[0].message.content

                    # Parse and validate response
                    batch_results = self._parse_response(result_text)
                    if batch_results:
                        results.update(batch_results)
                        break
                    elif attempt < max_retries - 1:
                        # Retry with more explicit instructions
                        prompt = self._add_error_context(prompt, result_text)
                        await asyncio.sleep(2 ** attempt)
                    else:
                        # BUG FIX: previously a batch whose final attempt
                        # parsed to nothing was dropped silently; log it so
                        # missing parameters can be traced.
                        self.logger.warning(
                            f"Batch {batch_idx + 1} produced no valid results after {max_retries} attempts")

                except Exception as e:
                    self.logger.error(f"Error processing batch {batch_idx + 1}: {str(e)}")
                    if attempt == max_retries - 1:
                        self.logger.warning(f"Failed to process batch after {max_retries} attempts")
                    else:
                        await asyncio.sleep(2 ** attempt)  # Exponential backoff

        return results

    def _create_response_template(self) -> str:
        """Create response template with current hierarchy levels."""
        levels_str = "|".join([f'"{level}"' for level in self.hierarchy_levels])
        return f"""
You must return a JSON object with parameter analysis results. Each parameter name should use the exact full path from the input.

Required Format:
{{
    "parameter_name": {{
        "configurable": true/false,
        "managed_by": "IT/OT",
        "edit_level": {levels_str},
        "required": true/false
    }}
}}

Rules:
1. Use EXACT parameter names from input
2. All fields are required for each parameter
3. managed_by must be "IT" or "OT"
4. edit_level must be one of: {self.hierarchy_levels}
5. required must be true/false

Configurable Parameter Guidelines:
- Set to true if parameter should be modifiable in production
- Set to false for fixed and static system configurations
- Consider the following factors:
  * Runtime modifiability needs
  * Operational flexibility requirements
  * System stability impact
  * Security implications
  * Compliance requirements

Parameters typically configurable:
- Resource allocations (memory, CPU)
- Connection settings (ports, endpoints)
- Performance tuning parameters
- Operational thresholds
- Environment-specific values

Parameters typically not configurable:
- Core security settings
- System identifiers
- Protocol versions
- Fixed architectural components
- Compliance-mandated values

Hierarchy (edit_level) Understanding:
- Parameters at higher levels affect all environments below them
- Changes at higher levels have broader organizational impact
- Lower level parameters are more specific to local environments
- Consider the scope of impact when determining hierarchy level
- Parameters that affect multiple environments should be managed higher
- Local customizations should be allowed at appropriate levels
- Critical security and compliance settings belong at higher levels
- Operational parameters typically belong at levels closer to usage

IT (Information Technology) Context:
- Manages enterprise-wide security and infrastructure
- Handles authentication, certificates, and security policies
- Controls infrastructure configurations and networking
- Responsible for system-wide monitoring and compliance

OT (Operational Technology) Context:
- Manages factory-specific operational parameters
- Controls production-related configurations
- Handles day-to-day operational adjustments
- Responsible for local performance optimization

Required Field Guidelines:
- Set to true if parameter must be in solution template
- Set to false if parameter can be omitted from template"""

    def _add_error_context(self, prompt: str, failed_response: str) -> str:
        """Add error context to prompt for retry attempts."""
        error_context = (
            "\nPrevious response was invalid. Common issues found:\n"
            "1. Parameter names must match input exactly\n"
            "2. Each parameter must have all required fields\n"
            "3. managed_by must be exactly 'IT' or 'OT'\n"
            f"4. edit_level must be one of: {self.hierarchy_levels}\n"
            f"\nInvalid response was:\n{failed_response}\n\n"
            "Try again with the EXACT parameter names from the input."
        )
        return f"{prompt}\n{error_context}"

    def _parse_response(self, response: str) -> Optional[Dict[str, Dict[str, Any]]]:
        """Parse and validate the API response.

        Returns:
            Mapping of parameter name to validated result, or None when
            nothing valid could be extracted (treated as a retryable
            failure by analyze_parameters).
        """
        try:
            # Find JSON content (handle cases where there might be additional text)
            start = response.find('{')
            end = response.rfind('}') + 1
            if start == -1 or end == 0:
                self.logger.error("No JSON content found in response")
                return None

            json_str = response[start:end]

            # Try to parse JSON
            result = json.loads(json_str)

            # Validate structure
            if not isinstance(result, dict):
                self.logger.error("Response is not a dictionary")
                return None

            # Validate each parameter result
            validated = {}
            for param_name, param_data in result.items():
                if self._validate_parameter_result(param_data):
                    validated[param_name] = param_data
                else:
                    self.logger.warning(f"Invalid result format for parameter {param_name}")
                    self.logger.warning(f"Invalid data: {param_data}")

            return validated if validated else None

        except json.JSONDecodeError as e:
            self.logger.error(f"Failed to parse JSON response: {str(e)}")
            self.logger.error(f"Invalid JSON: {json_str}")
            return None
        except Exception as e:
            self.logger.error(f"Error processing response: {str(e)}")
            return None

    def _validate_parameter_result(self, result: Any) -> bool:
        """Validate the structure of a parameter result.

        A valid result is a dict with exactly-typed 'configurable',
        'managed_by', 'edit_level' and 'required' fields.
        """
        try:
            required_fields = {
                'configurable': lambda x: isinstance(x, bool),
                'managed_by': lambda x: x in ('IT', 'OT'),
                'edit_level': lambda x: x in self.hierarchy_levels,
                'required': lambda x: isinstance(x, bool)
            }

            return (
                isinstance(result, dict) and
                all(field in result for field in required_fields) and
                all(check(result[field]) for field, check in required_fields.items())
            )

        except Exception as e:
            self.logger.error(f"Validation error: {str(e)}")
            return False
import unittest
import os
import sys
from unittest.mock import patch, Mock, AsyncMock
from typing import Dict, Any

# Add src directory to Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))

from ai_analyzer.analyzer import AIParameterAnalyzer
from helm_parser.parser import ChartData, ChartParameter

class TestAIParameterAnalyzer(unittest.IsolatedAsyncioTestCase):  # Changed to IsolatedAsyncioTestCase
    """Async unit tests for AIParameterAnalyzer.

    NOTE(review): asyncSetUp builds a real AIParameterAnalyzer, which in turn
    constructs an AzureOpenAIClient with fake credentials. No network call is
    made until analyze_parameters is awaited, and those paths are patched
    below — confirm this holds if the client constructor ever starts
    validating credentials eagerly.
    """

    async def asyncSetUp(self):  # Changed to asyncSetUp
        """Set up test environment"""
        self.endpoint = "https://test.openai.azure.com"
        self.api_key = "test-key"
        self.deployment = "gpt-4"
        self.custom_prompt = "Test prompt with {chart_context}"
        self.hierarchy_levels = ["factory", "line", "machine"]

        self.analyzer = AIParameterAnalyzer(
            endpoint=self.endpoint,
            api_key=self.api_key,
            deployment=self.deployment,
            custom_prompt=self.custom_prompt,
            hierarchy_levels=self.hierarchy_levels
        )

        # Sample chart data: one flat parameter and one nested parameter,
        # plus a single dependency, covering both _format_parameter branches.
        self.chart_data = ChartData(
            name="test-chart",
            version="1.0.0",
            description="Test description",
            parameters={
                "param1": ChartParameter(
                    name="param1",
                    type="string",
                    default_value="value1",
                    description="Parameter 1"
                ),
                "param2": ChartParameter(
                    name="param2",
                    type="int",
                    default_value=42,
                    description="Parameter 2",
                    nested_path=["nested", "param2"]
                )
            },
            dependencies=[{"name": "dep1"}]
        )

    async def test_init(self):  # Changed to async
        """Test analyzer initialization"""
        self.assertEqual(self.analyzer.system_prompt, self.custom_prompt)
        self.assertEqual(self.analyzer.hierarchy_levels, self.hierarchy_levels)

        # Test default hierarchy levels
        default_analyzer = AIParameterAnalyzer(
            endpoint=self.endpoint,
            api_key=self.api_key,
            deployment=self.deployment
        )
        self.assertEqual(default_analyzer.hierarchy_levels, ['factory', 'line'])

    async def test_build_chart_context(self):  # Changed to async
        """Test chart context building"""
        context = self.analyzer._build_chart_context(self.chart_data)

        self.assertEqual(context["name"], "test-chart")
        self.assertEqual(context["version"], "1.0.0")
        self.assertEqual(context["description"], "Test description")
        self.assertEqual(context["dependencies"], ["dep1"])
        self.assertEqual(context["application_type"], "web_server")
        self.assertEqual(context["deployment_type"], "kubernetes")
        self.assertEqual(context["target_environment"], "production")
        self.assertEqual(context["hierarchy_levels"], self.hierarchy_levels)

    async def test_format_parameter(self):  # Changed to async
        """Test parameter formatting"""
        # Test simple parameter: path should fall back to the flat name
        param1 = self.chart_data.parameters["param1"]
        formatted1 = self.analyzer._format_parameter(param1)

        self.assertEqual(formatted1["name"], "param1")
        self.assertEqual(formatted1["type"], "string")
        self.assertEqual(formatted1["default_value"], "value1")
        self.assertEqual(formatted1["description"], "Parameter 1")
        self.assertEqual(formatted1["path"], "param1")

        # Test nested parameter: path should be the dotted nested_path
        param2 = self.chart_data.parameters["param2"]
        formatted2 = self.analyzer._format_parameter(param2)

        self.assertEqual(formatted2["name"], "param2")
        self.assertEqual(formatted2["path"], "nested.param2")

    @patch('ai_analyzer.analyzer.AIParameterAnalyzer._build_chart_context')
    @patch('ai_analyzer.client.AzureOpenAIClient.analyze_parameters')
    async def test_analyze_parameters_success(self, mock_analyze, mock_context):
        """Test successful parameter analysis"""
        # Mock responses: param2 is non-configurable and must be filtered out
        mock_context.return_value = {"test": "context"}
        mock_analyze.return_value = {
            "param1": {
                "configurable": True,
                "required": True,
                "managed_by": "IT",
                "edit_level": "factory"
            },
            "param2": {
                "configurable": False,
                "required": False,
                "managed_by": "OT",
                "edit_level": "line"
            }
        }

        results = await self.analyzer.analyze_parameters(self.chart_data)

        # Only configurable parameters should be included
        self.assertEqual(len(results), 1)
        self.assertIn("param1", results)
        self.assertNotIn("param2", results)

    @patch('ai_analyzer.client.AzureOpenAIClient.analyze_parameters')
    async def test_analyze_parameters_error(self, mock_analyze):
        """Test error handling in parameter analysis"""
        # The analyzer logs and re-raises client exceptions unchanged.
        mock_analyze.side_effect = Exception("API Error")

        with self.assertRaises(Exception):
            await self.analyzer.analyze_parameters(self.chart_data)

    async def test_get_analysis_stats(self):  # Changed to async
        """Test analysis statistics generation"""
        results = {
            "param1": {
                "configurable": True,
                "required": True,
                "managed_by": "IT",
                "edit_level": "factory"
            },
            "param2": {
                "configurable": True,
                "required": False,
                "managed_by": "OT",
                "edit_level": "line"
            },
            "param3": {
                "configurable": True,
                "required": True,
                "managed_by": "OT",
                "edit_level": "machine"
            }
        }

        stats = self.analyzer.get_analysis_stats(results)

        self.assertEqual(stats["total_parameters"], 3)
        self.assertEqual(stats["configurable"], 3)
        self.assertEqual(stats["required"], 2)
        self.assertEqual(stats["it_managed"], 1)
        self.assertEqual(stats["ot_managed"], 2)
        self.assertEqual(stats["hierarchy_levels"]["factory"], 1)
        self.assertEqual(stats["hierarchy_levels"]["line"], 1)
        self.assertEqual(stats["hierarchy_levels"]["machine"], 1)

    @patch('ai_analyzer.client.AzureOpenAIClient.analyze_parameters')
    async def test_analyze_empty_parameters(self, mock_analyze):
        """Test analysis with no parameters"""
        empty_chart = ChartData(
            name="empty-chart",
            version="1.0.0",
            description=None,
            parameters={},
            dependencies=[]
        )

        mock_analyze.return_value = {}
        results = await self.analyzer.analyze_parameters(empty_chart)
        self.assertEqual(len(results), 0)

if __name__ == '__main__':
    unittest.main()
import unittest
import os
import sys
import json
from unittest.mock import patch, Mock, AsyncMock
from typing import Dict, Any

# Add src directory to Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from ai_analyzer.client import AzureOpenAIClient

class TestAzureOpenAIClient(unittest.IsolatedAsyncioTestCase):
    """Unit tests for AzureOpenAIClient with the OpenAI SDK fully mocked."""

    async def asyncSetUp(self):
        """Set up test environment"""
        self.endpoint = "https://test.openai.azure.com"
        self.api_key = "test-key"
        self.deployment = "gpt-4"
        self.hierarchy_levels = ["factory", "line", "machine"]

        # Create a mock client instance
        self.mock_openai_client = AsyncMock()

        # Patch the SDK constructor so AzureOpenAIClient wires itself to the
        # mock instead of a real AsyncAzureOpenAI client.
        with patch('ai_analyzer.client.AsyncAzureOpenAI') as mock_azure:
            mock_azure.return_value = self.mock_openai_client
            self.client = AzureOpenAIClient(
                endpoint=self.endpoint,
                api_key=self.api_key,
                deployment=self.deployment,
                hierarchy_levels=self.hierarchy_levels
            )

        # Sample parameters for testing
        self.test_parameters = [
            {
                "name": "param1",
                "type": "string",
                "default_value": "value1",
                "description": "Parameter 1",
                "path": "param1"
            }
        ]

        self.test_prompt = "Test prompt"

        # Valid response format
        self.valid_response = {
            "param1": {
                "configurable": True,
                "managed_by": "IT",
                "edit_level": "factory",
                "required": True
            }
        }

    def create_mock_response(self, content):
        """Helper method to create mock response"""
        # Mirrors the SDK shape: response.choices[0].message.content (str).
        response = Mock()
        if isinstance(content, str):
            response.choices = [Mock(message=Mock(content=content))]
        else:
            response.choices = [Mock(message=Mock(content=json.dumps(content)))]
        return response

    async def test_analyze_parameters_success(self):
        """Test successful parameter analysis"""
        # Set up mock response
        mock_response = self.create_mock_response(self.valid_response)
        self.mock_openai_client.chat.completions.create.return_value = mock_response

        # Test successful analysis
        results = await self.client.analyze_parameters(
            self.test_parameters,
            self.test_prompt
        )

        # Verify results
        self.assertEqual(len(results), 1)
        self.assertIn("param1", results)
        self.assertTrue(results["param1"]["configurable"])
        self.assertEqual(results["param1"]["managed_by"], "IT")
        self.assertEqual(results["param1"]["edit_level"], "factory")
        self.assertTrue(results["param1"]["required"])

        # Verify API call
        self.mock_openai_client.chat.completions.create.assert_called_once()
        call_args = self.mock_openai_client.chat.completions.create.call_args
        self.assertEqual(call_args[1]["model"], self.deployment)
        self.assertEqual(call_args[1]["temperature"], 0.0)

    async def test_analyze_parameters_invalid_response(self):
        """Test handling of invalid API responses"""
        # Each invalid payload should yield an empty result dict after
        # the client exhausts its retries.
        test_cases = [
            # Invalid JSON
            ("invalid json", {}),

            # Missing required fields
            ({"param1": {"configurable": True}}, {}),

            # Invalid field values
            ({"param1": {
                "configurable": True,
                "managed_by": "INVALID",
                "edit_level": "factory",
                "required": True
            }}, {}),

            # Empty response
            ({}, {})
        ]

        for content, expected in test_cases:
            mock_response = self.create_mock_response(content)
            self.mock_openai_client.chat.completions.create.return_value = mock_response

            results = await self.client.analyze_parameters(
                self.test_parameters,
                self.test_prompt
            )
            self.assertEqual(results, expected)

    async def test_analyze_parameters_retries(self):
        """Test retry mechanism"""
        # NOTE(review): error_response is an AsyncMock *object* placed in the
        # side_effect list, so the first create() call returns it rather than
        # raising. The retry is therefore triggered by an unparsable response
        # (Mock content has no JSON), not by an exception — confirm this is
        # the intended path to exercise.
        error_response = AsyncMock(side_effect=Exception("API Error"))
        success_response = self.create_mock_response(self.valid_response)

        self.mock_openai_client.chat.completions.create.side_effect = [
            error_response,
            success_response
        ]

        results = await self.client.analyze_parameters(
            self.test_parameters,
            self.test_prompt,
            max_retries=2
        )

        self.assertIn("param1", results)
        self.assertEqual(self.mock_openai_client.chat.completions.create.call_count, 2)

    def test_validate_parameter_result(self):
        """Test parameter validation"""
        test_cases = [
            # Valid case
            (self.valid_response["param1"], True),

            # Missing field
            ({"managed_by": "IT", "edit_level": "factory", "required": True}, False),

            # Invalid managed_by
            ({
                "configurable": True,
                "managed_by": "INVALID",
                "edit_level": "factory",
                "required": True
            }, False),

            # Invalid edit_level
            ({
                "configurable": True,
                "managed_by": "IT",
                "edit_level": "invalid",
                "required": True
            }, False),

            # Invalid type
            ({
                "configurable": "true",  # Should be boolean
                "managed_by": "IT",
                "edit_level": "factory",
                "required": True
            }, False)
        ]

        for test_input, expected in test_cases:
            result = self.client._validate_parameter_result(test_input)
            self.assertEqual(result, expected)

    def test_create_response_template(self):
        """Test response template generation"""
        template = self.client._create_response_template()

        # Check template contains all required sections
        self.assertIn("Required Format:", template)
        self.assertIn("configurable", template)
        self.assertIn("managed_by", template)
        self.assertIn("edit_level", template)
        self.assertIn("required", template)

        # Check hierarchy levels are included
        for level in self.hierarchy_levels:
            self.assertIn(level, template)

if __name__ == '__main__':
    unittest.main()
"""
Helm chart parsing module for WO Artifact Generator
"""

import os
import yaml
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from utils.logger import LoggerMixin

@dataclass
class ChartParameter:
    """Data class representing a Helm chart parameter"""
    name: str
    type: str
    default_value: Optional[Any] = None
    description: Optional[str] = None
    required: bool = False
    nested_path: Optional[List[str]] = None

@dataclass
class ChartData:
    """Data class representing parsed Helm chart data"""
    name: str
    version: str
    description: Optional[str]
    parameters: Dict[str, ChartParameter]
    dependencies: List[Dict[str, Any]]

class BaseChartParser(ABC, LoggerMixin):
    """Abstract base class for Helm chart parsing"""

    def __init__(self, chart_path: str) -> None:
        """
        Initialize the chart parser.

        Args:
            chart_path: Path to the Helm chart directory
        """
        super().__init__()
        self.chart_path = chart_path
        self.values_file = os.path.join(chart_path, 'values.yaml')
        self.chart_file = os.path.join(chart_path, 'Chart.yaml')

    @abstractmethod
    def parse(self) -> ChartData:
        """
        Parse the Helm chart and extract relevant data.

        Returns:
            ChartData object containing parsed information
        """
        pass

    def _read_yaml(self, file_path: str) -> Dict[str, Any]:
        """
        Read and parse a YAML file.

        Args:
            file_path: Path to the YAML file

        Returns:
            Dictionary containing parsed YAML data. An empty dict is
            returned when the file is missing or empty.
            (DOC FIX: the original docstring claimed FileNotFoundError is
            raised for a missing file; the code logs a warning and returns
            {} instead.)

        Raises:
            yaml.YAMLError: If the file is not valid YAML
        """
        try:
            if not os.path.exists(file_path):
                self.logger.warning(f"File not found: {file_path}")
                return {}

            # BUG FIX: open() without an explicit encoding uses the locale
            # default, which mis-decodes UTF-8 charts on some platforms
            # (e.g. Windows cp1252). Helm chart files are UTF-8.
            with open(file_path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f) or {}

        except yaml.YAMLError as e:
            self.logger.error(f"Error parsing YAML file {file_path}: {str(e)}")
            raise

class HelmChartParser(BaseChartParser):
    """Implementation of Helm chart parser"""

    def _extract_parameters(self, data: Dict[str, Any], path: List[str] = None) -> Dict[str, ChartParameter]:
        """
        Recursively extract parameters from values.yaml

        Args:
            data: Dictionary containing values data
            path: Current path in nested structure

        Returns:
            Dictionary mapping dotted parameter names to ChartParameter objects
        """
        if path is None:
            path = []

        parameters = {}

        for key, value in data.items():
            current_path = path + [key]

            if isinstance(value, dict):
                # Recursively process nested dictionaries; note an empty dict
                # value produces no parameter entry for that key.
                nested_params = self._extract_parameters(value, current_path)
                parameters.update(nested_params)
            else:
                # Create parameter entry keyed by the dotted path
                param_name = '.'.join(current_path)
                param_type = self._infer_type(value)

                parameters[param_name] = ChartParameter(
                    name=param_name,
                    type=param_type,
                    default_value=value,
                    nested_path=current_path,
                    required=False  # Will be updated by schema generator
                )

        return parameters

    def _infer_type(self, value: Any) -> str:
        """
        Infer the type of a parameter value.

        Args:
            value: Parameter value

        Returns:
            String representing the parameter type
        """
        # bool must be tested before int: in Python, bool is a subclass of int.
        if isinstance(value, bool):
            return 'boolean'
        elif isinstance(value, int):
            return 'int'
        elif isinstance(value, float):
            return 'float'
        elif isinstance(value, list):
            if value:
                # Element type is inferred from the first element only.
                element_type = self._infer_type(value[0])
                return f'array[{element_type}]'
            return 'array[string]'  # Default to string array if empty
        else:
            return 'string'

    def parse(self) -> ChartData:
        """
        Parse the Helm chart and extract relevant data.

        Returns:
            ChartData object containing parsed information

        Raises:
            ValueError: If Chart.yaml is missing or invalid
        """
        # Read Chart.yaml
        chart_info = self._read_yaml(self.chart_file)
        if not chart_info:
            raise ValueError(f"Invalid or missing Chart.yaml in {self.chart_path}")

        # Read values.yaml (may be empty)
        values = self._read_yaml(self.values_file)

        # Extract parameters
        parameters = self._extract_parameters(values)

        # Create ChartData object
        return ChartData(
            name=chart_info.get('name', ''),
            version=chart_info.get('version', ''),
            description=chart_info.get('description'),
            parameters=parameters,
            dependencies=chart_info.get('dependencies', [])
        )
helm_parser.parser import HelmChartParser, ChartParameter, ChartData + +class TestHelmChartParser(unittest.TestCase): + def setUp(self): + """Set up test environment""" + self.test_chart_path = "test_chart" + self.chart_yaml = { + "name": "test-chart", + "version": "1.0.0", + "description": "Test chart description", + "dependencies": [ + {"name": "dep1", "version": "1.0.0"}, + {"name": "dep2", "version": "2.0.0"} + ] + } + + self.values_yaml = { + "simple": "value", + "boolean": True, + "number": 42, + "float_num": 3.14, + "array": ["item1", "item2"], + "nested": { + "param1": True, + "param2": 100, + "deep": { + "param3": "value3" + } + } + } + + self.parser = HelmChartParser(self.test_chart_path) + + def test_infer_type_boolean(self): + """Test type inference for boolean values""" + self.assertEqual(self.parser._infer_type(True), 'boolean') + self.assertEqual(self.parser._infer_type(False), 'boolean') + + def test_infer_type_integer(self): + """Test type inference for integer values""" + self.assertEqual(self.parser._infer_type(42), 'int') + self.assertEqual(self.parser._infer_type(-17), 'int') + self.assertEqual(self.parser._infer_type(0), 'int') + + def test_infer_type_float(self): + """Test type inference for float values""" + self.assertEqual(self.parser._infer_type(3.14), 'float') + self.assertEqual(self.parser._infer_type(-2.5), 'float') + self.assertEqual(self.parser._infer_type(0.0), 'float') + + def test_infer_type_string(self): + """Test type inference for string values""" + self.assertEqual(self.parser._infer_type("hello"), 'string') + self.assertEqual(self.parser._infer_type(""), 'string') + self.assertEqual(self.parser._infer_type("123"), 'string') + + def test_infer_type_array(self): + """Test type inference for array values""" + self.assertEqual(self.parser._infer_type([1, 2, 3]), 'array[int]') + self.assertEqual(self.parser._infer_type(["a", "b"]), 'array[string]') + self.assertEqual(self.parser._infer_type([]), 'array[string]') + 
self.assertEqual(self.parser._infer_type([True, False]), 'array[boolean]') + + def test_extract_parameters_flat(self): + """Test parameter extraction for flat structure""" + flat_values = { + "param1": "value1", + "param2": True, + "param3": 42 + } + + params = self.parser._extract_parameters(flat_values) + + self.assertEqual(len(params), 3) + self.assertIn("param1", params) + self.assertIn("param2", params) + self.assertIn("param3", params) + + self.assertEqual(params["param1"].type, "string") + self.assertEqual(params["param2"].type, "boolean") + self.assertEqual(params["param3"].type, "int") + + def test_extract_parameters_nested(self): + """Test parameter extraction for nested structure""" + params = self.parser._extract_parameters(self.values_yaml) + + self.assertIn("nested.param1", params) + self.assertIn("nested.deep.param3", params) + + nested_param = params["nested.param1"] + self.assertEqual(nested_param.name, "nested.param1") + self.assertEqual(nested_param.type, "boolean") + self.assertEqual(nested_param.nested_path, ["nested", "param1"]) + + deep_param = params["nested.deep.param3"] + self.assertEqual(deep_param.name, "nested.deep.param3") + self.assertEqual(deep_param.type, "string") + self.assertEqual(deep_param.nested_path, ["nested", "deep", "param3"]) + + def test_extract_parameters_empty(self): + """Test parameter extraction with empty values""" + params = self.parser._extract_parameters({}) + self.assertEqual(len(params), 0) + + @patch('os.path.exists') + def test_read_yaml_missing_file(self, mock_exists): + """Test reading non-existent YAML file""" + mock_exists.return_value = False + result = self.parser._read_yaml("nonexistent.yaml") + self.assertEqual(result, {}) + + @patch('builtins.open') + @patch('os.path.exists') + def test_read_yaml_empty_file(self, mock_exists, mock_file): + """Test reading empty YAML file""" + # Mock file existence + mock_exists.return_value = True + + # Mock empty file + mock_file_handle = 
mock_open(read_data="").return_value + mock_file.return_value = mock_file_handle + result = self.parser._read_yaml("empty.yaml") + self.assertEqual(result, {}) + + @patch('builtins.open') + @patch('os.path.exists') + def test_read_yaml_invalid_yaml(self, mock_exists, mock_file): + """Test reading invalid YAML file""" + # Mock file existence + mock_exists.return_value = True + + # Mock file read operation to return invalid YAML + mock_file_handle = mock_open(read_data="invalid: yaml: :").return_value + mock_file.return_value = mock_file_handle + + # Mock yaml.safe_load to raise YAMLError + with patch('yaml.safe_load', side_effect=yaml.YAMLError("Invalid YAML")): + with self.assertRaises(yaml.YAMLError): + self.parser._read_yaml("invalid.yaml") + + @patch('helm_parser.parser.HelmChartParser._read_yaml') + def test_parse_complete(self, mock_read_yaml): + """Test complete chart parsing""" + mock_read_yaml.side_effect = [ + self.chart_yaml, # Chart.yaml + self.values_yaml # values.yaml + ] + + chart_data = self.parser.parse() + + self.assertIsInstance(chart_data, ChartData) + self.assertEqual(chart_data.name, "test-chart") + self.assertEqual(chart_data.version, "1.0.0") + self.assertEqual(chart_data.description, "Test chart description") + self.assertEqual(len(chart_data.dependencies), 2) + self.assertTrue(len(chart_data.parameters) > 0) + + @patch('helm_parser.parser.HelmChartParser._read_yaml') + def test_parse_missing_chart_yaml(self, mock_read_yaml): + """Test parsing with missing Chart.yaml""" + mock_read_yaml.side_effect = [{}, {}] # Empty Chart.yaml and values.yaml + + with self.assertRaises(ValueError): + self.parser.parse() + + @patch('helm_parser.parser.HelmChartParser._read_yaml') + def test_parse_missing_values_yaml(self, mock_read_yaml): + """Test parsing with missing values.yaml""" + mock_read_yaml.side_effect = [ + self.chart_yaml, # Chart.yaml + {} # Empty values.yaml + ] + + chart_data = self.parser.parse() + self.assertEqual(len(chart_data.parameters), 
0) + + def test_parameter_defaults(self): + """Test ChartParameter default values""" + param = ChartParameter(name="test", type="string") + self.assertIsNone(param.default_value) + self.assertIsNone(param.description) + self.assertFalse(param.required) + self.assertIsNone(param.nested_path) + + def test_chart_data_defaults(self): + """Test ChartData default values""" + data = ChartData( + name="test", + version="1.0.0", + description=None, + parameters={}, + dependencies=[] + ) + self.assertEqual(data.name, "test") + self.assertEqual(data.version, "1.0.0") + self.assertIsNone(data.description) + self.assertEqual(len(data.parameters), 0) + self.assertEqual(len(data.dependencies), 0) + +if __name__ == '__main__': + unittest.main() diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/__init__.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/__init__.py new file mode 100644 index 00000000000..cda307e3855 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/__init__.py @@ -0,0 +1,7 @@ +""" +Schema generator package +""" + +from .generator import SchemaGenerator + +__all__ = ['SchemaGenerator'] diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/generator.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/generator.py new file mode 100644 index 00000000000..1566902d58d --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/generator.py @@ -0,0 +1,142 @@ +""" +Schema generator module +""" + +import yaml +from typing import 
Dict, Any +import asyncio +from utils.logger import LoggerMixin +from helm_parser.parser import ChartData, ChartParameter +from ai_analyzer.analyzer import AIParameterAnalyzer + +class SchemaGenerator(LoggerMixin): + """Generator for WO schemas from Helm chart data""" + + def __init__(self, ai_analyzer: AIParameterAnalyzer): + """ + Initialize the schema generator. + + Args: + ai_analyzer: AI analyzer instance for parameter analysis + """ + super().__init__() + self.ai_analyzer = ai_analyzer + self._debug_info = {} # Store additional info for debugging + + async def generate(self, chart_data: ChartData, name: str, version: str) -> str: + """ + Generate WO schema from chart data. + + Args: + chart_data: Parsed chart data + name: Schema name + version: Schema version + + Returns: + YAML string containing generated schema + """ + schema = { + 'name': name, + 'version': version, + 'rules': { + 'configs': {} + } + } + + try: + self.logger.info("Using AI-based parameter analysis") + analysis_results = await self.ai_analyzer.analyze_parameters(chart_data) + + if not analysis_results: + raise ValueError("AI analyzer returned empty results") + + self.logger.info(f"AI analysis returned {len(analysis_results)} parameters") + + # Generate schema based on AI analysis + for param_name, param in chart_data.parameters.items(): + if param_name in analysis_results and analysis_results[param_name]['configurable']: + analysis = analysis_results[param_name] + param_schema = self._generate_parameter_schema(param, analysis) + if param_schema: + schema['rules']['configs'][param_name] = param_schema + self.logger.debug(f"Added AI-analyzed parameter: {param_name}") + # Store additional info for debugging + self._store_debug_info(param_name, param, analysis) + + # Log analysis statistics + stats = self.ai_analyzer.get_analysis_stats(analysis_results) + self.logger.info(f"AI Analysis Stats: {stats}") + + except Exception as e: + self.logger.error(f"Schema generation failed: {str(e)}", 
# Add the chart description to the schema when one is present
+ + Returns: + Dictionary containing debug information for parameters + """ + return self._debug_info diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/test_generator.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/test_generator.py new file mode 100644 index 00000000000..974d38c9a93 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/schema_generator/test_generator.py @@ -0,0 +1,226 @@ +import unittest +import os +import sys +import yaml +import logging +from unittest.mock import patch, Mock, AsyncMock +from typing import Dict, Any + +# Add src directory to Python path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) +from schema_generator.generator import SchemaGenerator +from helm_parser.parser import ChartData, ChartParameter +from ai_analyzer.analyzer import AIParameterAnalyzer + +class TestSchemaGenerator(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + """Set up test environment""" + # Setup logging + logging.getLogger("SchemaGenerator").setLevel(logging.DEBUG) + + # Create mock AI analyzer + self.mock_analyzer = AsyncMock(spec=AIParameterAnalyzer) + self.generator = SchemaGenerator(ai_analyzer=self.mock_analyzer) + + # Sample chart data + self.chart_data = ChartData( + name="test-chart", + version="1.0.0", + description="Test description", + parameters={ + "param1": ChartParameter( + name="param1", + type="string", + default_value="value1", + description="Parameter 1", + ), + "param2": ChartParameter( + name="param2", + type="int", + default_value=42, + description="Parameter 2", + required=True + ) + }, + dependencies=[] + ) + + # Sample analysis results + self.analysis_results = { + "param1": { + "configurable": True, + "managed_by": "IT", + 
"edit_level": "factory", + "required": True + }, + "param2": { + "configurable": False, + "managed_by": "OT", + "edit_level": "line", + "required": False + } + } + + # Sample stats + self.stats = { + "total_parameters": 2, + "configurable": 1, + "required": 1, + "it_managed": 1, + "ot_managed": 1 + } + + async def test_generate_success(self): + """Test successful schema generation""" + # Setup mock analyzer + self.mock_analyzer.analyze_parameters.return_value = self.analysis_results + self.mock_analyzer.get_analysis_stats.return_value = self.stats + + with self.assertLogs("SchemaGenerator", level='INFO') as log: + schema_yaml = await self.generator.generate( + chart_data=self.chart_data, + name="test-schema", + version="1.0.0" + ) + + # Verify logging + log_text = "\n".join(log.output) + self.assertIn("Using AI-based parameter analysis", log_text) + self.assertIn(f"AI analysis returned {len(self.analysis_results)} parameters", log_text) + self.assertIn(f"AI Analysis Stats: {self.stats}", log_text) + + # Parse generated YAML + schema = yaml.safe_load(schema_yaml) + + # Verify basic structure + self.assertEqual(schema["name"], "test-schema") + self.assertEqual(schema["version"], "1.0.0") + self.assertEqual(schema["description"], "Test description") + + # Verify configs + configs = schema["rules"]["configs"] + self.assertIn("param1", configs) + self.assertNotIn("param2", configs) # Not configurable + + # Verify parameter schema + param1_schema = configs["param1"] + self.assertEqual(param1_schema["type"], "string") + self.assertEqual(param1_schema["required"], True) + self.assertEqual(param1_schema["editableAt"], ["factory"]) + self.assertEqual(param1_schema["editableBy"], ["IT"]) + + async def test_generate_empty_analysis(self): + """Test handling of empty analysis results""" + self.mock_analyzer.analyze_parameters.return_value = {} + + with self.assertLogs("SchemaGenerator", level='ERROR') as log: + with self.assertRaises(ValueError) as context: + await 
self.generator.generate( + chart_data=self.chart_data, + name="test-schema", + version="1.0.0" + ) + self.assertIn("AI analyzer returned empty results", str(context.exception)) + self.assertIn("Schema generation failed", log.output[0]) + + def test_generate_parameter_schema(self): + """Test parameter schema generation""" + param = ChartParameter( + name="test_param", + type="string", + default_value="test" + ) + + analysis = { + "configurable": True, + "managed_by": "IT", + "edit_level": "factory", + "required": True + } + + schema = self.generator._generate_parameter_schema(param, analysis) + + self.assertEqual(schema["type"], "string") + self.assertEqual(schema["required"], True) + self.assertEqual(schema["editableAt"], ["factory"]) + self.assertEqual(schema["editableBy"], ["IT"]) + + def test_store_debug_info(self): + """Test debug info storage""" + param = ChartParameter( + name="test_param", + type="string", + default_value="test", + description="Test parameter" + ) + + analysis = { + "configurable": True, + "required": True + } + + self.generator._store_debug_info("test_param", param, analysis) + debug_info = self.generator.get_debug_info() + + self.assertIn("test_param", debug_info) + self.assertEqual(debug_info["test_param"]["defaultValue"], "test") + self.assertEqual(debug_info["test_param"]["description"], "Test parameter") + self.assertEqual(debug_info["test_param"]["required"], True) + + async def test_generate_yaml_error(self): + """Test YAML generation error handling""" + # Mock analyzer to return valid results first + self.mock_analyzer.analyze_parameters.return_value = self.analysis_results + self.mock_analyzer.get_analysis_stats.return_value = self.stats + + # Mock yaml.dump to raise an error + with patch('yaml.dump') as mock_dump: + mock_dump.side_effect = yaml.YAMLError("Cannot serialize to YAML") + + with self.assertLogs("SchemaGenerator", level='ERROR') as log: + with self.assertRaises(Exception) as context: + await self.generator.generate( + 
chart_data=self.chart_data, + name="test-schema", + version="1.0.0" + ) + self.assertIn("Error in schema generation", log.output[-1]) + self.assertIn("Cannot serialize to YAML", str(context.exception)) + + + async def test_generate_analysis_error(self): + """Test handling of analyzer errors""" + error_msg = "Analysis failed" + self.mock_analyzer.analyze_parameters.side_effect = Exception(error_msg) + + with self.assertLogs("SchemaGenerator", level='ERROR') as log: + with self.assertRaises(Exception) as context: + await self.generator.generate( + chart_data=self.chart_data, + name="test-schema", + version="1.0.0" + ) + self.assertEqual(str(context.exception), error_msg) + self.assertIn("Schema generation failed", log.output[0]) + + async def test_generate_invalid_parameter_schema(self): + """Test handling of invalid parameter schema generation""" + self.mock_analyzer.analyze_parameters.return_value = { + "param1": { + "configurable": True, + # Missing required fields + } + } + + with self.assertLogs("SchemaGenerator", level='ERROR') as log: + schema_yaml = await self.generator.generate( + chart_data=self.chart_data, + name="test-schema", + version="1.0.0" + ) + schema = yaml.safe_load(schema_yaml) + self.assertNotIn("param1", schema["rules"]["configs"]) + self.assertIn("Error generating parameter schema", log.output[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/__init__.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/__init__.py new file mode 100644 index 00000000000..26421802b40 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/__init__.py @@ -0,0 +1,7 @@ +""" +Template generator module initialization +""" + +from 
.generator import TemplateGenerator + +__all__ = ['TemplateGenerator'] diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/generator.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/generator.py new file mode 100644 index 00000000000..f828e2de204 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/artifact/wo_artifact/src/template_generator/generator.py @@ -0,0 +1,153 @@ +""" +Template generator module for WO Artifact Generator +""" + +import yaml +from typing import Dict, Any, List, Optional +from utils.logger import LoggerMixin +from helm_parser.parser import ChartData + +class TemplateGenerator(LoggerMixin): + """Generator for WO solution templates from Helm chart data""" + + def generate(self, chart_data: ChartData, name: str, version: str, schema: Dict[str, Any] = None) -> str: + """ + Generate WO solution template from chart data and schema. 
+ + Args: + chart_data: Parsed chart data + name: Schema name + version: Schema version + schema: Generated WO schema (optional) + + Returns: + YAML string containing generated solution template + """ + # Build template structure + template = { + 'schema': { + 'name': name, + 'version': version + } + } + + # Generate configs section using schema-defined parameters if available + if schema and 'rules' in schema and 'configs' in schema['rules']: + template['configs'] = self._generate_config_section_from_schema( + chart_data, + schema['rules']['configs'] + ) + elif chart_data.parameters: + # Fallback to all parameters if no schema + template['configs'] = self._generate_config_section(chart_data) + + # Add dependencies if present + if chart_data.dependencies: + template['dependencies'] = self._generate_dependencies_section(chart_data) + + # Return formatted YAML + return yaml.dump(template, sort_keys=False, allow_unicode=True) + + def _generate_config_section_from_schema(self, + chart_data: ChartData, + schema_configs: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate the configs section using schema-defined parameters. + + Args: + chart_data: Parsed chart data + schema_configs: Schema configuration rules + + Returns: + Dictionary containing the configs section + """ + configs = {} + + # Process only required parameters from schema + for param_name, param_config in schema_configs.items(): + if param_config.get('required', False): # Only include if required=True + if '.' in param_name: + path = param_name.split('.') + self._set_nested_value(configs, path, self._generate_template_value(param_name)) + else: + configs[param_name] = self._generate_template_value(param_name) + + return configs + + def _generate_config_section(self, chart_data: ChartData) -> Dict[str, Any]: + """ + Generate the configs section from all chart parameters (fallback). 
+ + Args: + chart_data: Parsed chart data + + Returns: + Dictionary containing the configs section + """ + configs = {} + + # Process each parameter + for param_name, param in chart_data.parameters.items(): + if param.nested_path: + # Handle nested parameters + self._set_nested_value(configs, param.nested_path, + self._generate_template_value(param_name)) + else: + # Handle top-level parameters + configs[param_name] = self._generate_template_value(param_name) + + return configs + + def _generate_template_value(self, param_name: str) -> str: + """ + Generate template value following Config Manager Templating Language. + + Args: + param_name: Parameter name + + Returns: + Template value string + """ + return f"${{{{$val({param_name})}}}}" + + def _set_nested_value( + self, + config_dict: Dict[str, Any], + path: List[str], + value: Any + ) -> None: + """ + Set a value in a nested dictionary structure. + + Args: + config_dict: Dictionary to modify + path: Path to the value location + value: Value to set + """ + current = config_dict + + # Create nested structure + for component in path[:-1]: + if component not in current: + current[component] = {} + current = current[component] + + # Set the final value + if path: + current[path[-1]] = value + + def _generate_dependencies_section(self, chart_data: ChartData) -> List[Dict[str, Any]]: + """ + Generate the dependencies section if chart has dependencies. 
List of dependency configurations (currently a hard-coded '/common/1.0.0' placeholder emitted whenever any dependency exists, regardless of the chart's actual dependency list)
{ + "type": "string", + "required": True + }, + "nested.param": { + "type": "int", + "required": False + }, + "deep.nested.param": { + "type": "boolean", + "required": True + } + } + } + } + + def test_generate_with_schema(self): + """Test template generation with schema""" + template_yaml = self.generator.generate( + chart_data=self.chart_data, + name="test-template", + version="1.0.0", + schema=self.schema + ) + + template = yaml.safe_load(template_yaml) + + # Verify basic structure + self.assertEqual(template["schema"]["name"], "test-template") + self.assertEqual(template["schema"]["version"], "1.0.0") + + # Verify configs - only required parameters should be included + configs = template["configs"] + self.assertIn("simple", configs) + self.assertNotIn("nested", configs) # not required + self.assertIn("deep", configs) + self.assertEqual(configs["simple"], "${{$val(simple)}}") + self.assertEqual(configs["deep"]["nested"]["param"], "${{$val(deep.nested.param)}}") + + def test_generate_without_schema(self): + """Test template generation without schema""" + template_yaml = self.generator.generate( + chart_data=self.chart_data, + name="test-template", + version="1.0.0" + ) + + template = yaml.safe_load(template_yaml) + + # All parameters should be included + configs = template["configs"] + self.assertIn("simple", configs) + self.assertIn("nested", configs) + self.assertIn("deep", configs) + self.assertEqual(configs["simple"], "${{$val(simple)}}") + self.assertEqual(configs["nested"]["param"], "${{$val(nested.param)}}") + self.assertEqual(configs["deep"]["nested"]["param"], "${{$val(deep.nested.param)}}") + + def test_generate_empty_chart(self): + """Test template generation with empty chart data""" + empty_chart = ChartData( + name="empty", + version="1.0.0", + description=None, + parameters={}, + dependencies=[] + ) + + template_yaml = self.generator.generate( + chart_data=empty_chart, + name="test-template", + version="1.0.0" + ) + + template = 
yaml.safe_load(template_yaml) + self.assertNotIn("configs", template) + self.assertNotIn("dependencies", template) + + def test_generate_with_dependencies(self): + """Test template generation with dependencies""" + template_yaml = self.generator.generate( + chart_data=self.chart_data, + name="test-template", + version="1.0.0" + ) + + template = yaml.safe_load(template_yaml) + + # Verify dependencies section + self.assertIn("dependencies", template) + dependencies = template["dependencies"] + self.assertEqual(len(dependencies), 1) + self.assertEqual(dependencies[0]["solutionTemplateId"], "/common/1.0.0") + self.assertEqual(dependencies[0]["solutionTemplateVersion"], "2.x.x") + self.assertEqual(dependencies[0]["configsToBeInjected"], []) + + def test_generate_without_dependencies(self): + """Test template generation without dependencies""" + chart_data = ChartData( + name="no-deps", + version="1.0.0", + description=None, + parameters=self.chart_data.parameters, + dependencies=[] + ) + + template_yaml = self.generator.generate( + chart_data=chart_data, + name="test-template", + version="1.0.0" + ) + + template = yaml.safe_load(template_yaml) + self.assertNotIn("dependencies", template) + + def test_generate_template_value(self): + """Test template value generation""" + value = self.generator._generate_template_value("test.param") + self.assertEqual(value, "${{$val(test.param)}}") + + def test_set_nested_value(self): + """Test nested value setting""" + config = {} + + # Test simple path + self.generator._set_nested_value(config, ["simple"], "value1") + self.assertEqual(config["simple"], "value1") + + # Test nested path + self.generator._set_nested_value(config, ["nested", "param"], "value2") + self.assertEqual(config["nested"]["param"], "value2") + + # Test deep nesting + self.generator._set_nested_value(config, ["a", "b", "c", "d"], "value3") + self.assertEqual(config["a"]["b"]["c"]["d"], "value3") + + # Test empty path + self.generator._set_nested_value(config, [], 
# NOTE(review): comparing config to itself is vacuous and always passes — snapshot config before the call (e.g. dict(config)) to actually verify no mutation
# src/test_wo_gen.py
import unittest
import asyncio
import os
import shutil
from unittest.mock import patch, Mock, AsyncMock
from argparse import Namespace
import sys

# Make the src directory importable when the tests run from elsewhere.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from wo_gen import parse_args, ensure_output_dir, main


class TestWOGen(unittest.TestCase):
    """Exercises argument parsing, output handling and the main workflow."""

    # argv shared by the parser tests: chart path plus all required options.
    BASE_ARGV = [
        'wo_gen.py',
        './test_chart',
        '--schema-name', 'test-schema',
        '--schema-version', '1.0.0',
        '--ai-endpoint', 'https://test.openai.azure.com',
        '--ai-key', 'test-key',
        '--ai-model', 'gpt-4',
    ]

    def setUp(self):
        """Create a scratch output dir name, canned args and an event loop."""
        self.test_dir = "test_output"
        self.test_args = {
            "chart_path": "./test_chart",
            "output_dir": self.test_dir,
            "schema_name": "test-schema",
            "schema_version": "1.0.0",
            "ai_endpoint": "https://test.openai.azure.com",
            "ai_key": "test-key",
            "ai_model": "gpt-4",
            "verbose": False,
            "prompt": None,
        }
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Remove scratch output and close the event loop."""
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir)
        self.loop.close()

    def test_parse_args_required(self):
        """All required arguments are parsed into the expected fields."""
        with patch('sys.argv', list(self.BASE_ARGV)):
            parsed = parse_args()
        self.assertEqual(parsed.chart_path, './test_chart')
        self.assertEqual(parsed.schema_name, 'test-schema')
        self.assertEqual(parsed.schema_version, '1.0.0')
        self.assertEqual(parsed.ai_endpoint, 'https://test.openai.azure.com')
        self.assertEqual(parsed.ai_key, 'test-key')
        self.assertEqual(parsed.ai_model, 'gpt-4')

    def test_parse_args_defaults(self):
        """Optional arguments fall back to their documented defaults."""
        with patch('sys.argv', list(self.BASE_ARGV)):
            parsed = parse_args()
        self.assertEqual(parsed.output_dir, './output')
        self.assertFalse(parsed.verbose)
        self.assertIsNone(parsed.prompt)

    def test_ensure_output_dir_new(self):
        """A missing output directory is created."""
        target = os.path.join(self.test_dir, "new_dir")
        self.assertFalse(os.path.exists(target))
        ensure_output_dir(target)
        self.assertTrue(os.path.exists(target))

    def test_ensure_output_dir_existing(self):
        """An already-present output directory is left intact."""
        target = os.path.join(self.test_dir, "existing_dir")
        os.makedirs(target)
        ensure_output_dir(target)
        self.assertTrue(os.path.exists(target))

    @patch('wo_gen.HierarchyManager')
    @patch('wo_gen.PromptManager')
    @patch('wo_gen.AIParameterAnalyzer')
    @patch('wo_gen.HelmChartParser')
    @patch('wo_gen.SchemaGenerator')
    @patch('wo_gen.TemplateGenerator')
    def test_main_workflow(self, mock_template_gen, mock_schema_gen,
                           mock_parser, mock_analyzer, mock_prompt,
                           mock_hierarchy):
        """main() wires every component together and writes both files."""
        hierarchy = Mock()
        hierarchy.get_hierarchy_levels.return_value = ['factory', 'line']
        mock_hierarchy.return_value = hierarchy

        prompt = Mock()
        prompt.get_prompt.return_value = "test prompt"
        mock_prompt.return_value = prompt

        chart_parser = Mock()
        chart_parser.parse.return_value = {"test": "data"}
        mock_parser.return_value = chart_parser

        mock_analyzer.return_value = Mock()

        # The schema generator is awaited by main(), so it needs an AsyncMock.
        schema_gen = AsyncMock()
        schema_gen.generate.return_value = "test schema"
        mock_schema_gen.return_value = schema_gen

        template_gen = Mock()
        template_gen.generate.return_value = "test template"
        mock_template_gen.return_value = template_gen

        with patch('wo_gen.parse_args', return_value=Namespace(**self.test_args)):
            self.loop.run_until_complete(main())

        # Every stage of the pipeline ran exactly once.
        hierarchy.update_hierarchy_levels.assert_called_once()
        prompt.get_prompt.assert_called_once()
        chart_parser.parse.assert_called_once()
        schema_gen.generate.assert_called_once()
        template_gen.generate.assert_called_once()

        # Both artifacts landed in the output directory.
        schema_file = os.path.join(self.test_dir, "test-schema-schema.yaml")
        template_file = os.path.join(self.test_dir, "test-schema-template.yaml")
        self.assertTrue(os.path.exists(schema_file))
        self.assertTrue(os.path.exists(template_file))

    @patch('wo_gen.HierarchyManager')
    @patch('wo_gen.PromptManager')
    def test_main_error_handling(self, mock_prompt, mock_hierarchy):
        """Failures inside the workflow propagate out of main()."""
        prompt = Mock()
        prompt.get_prompt.side_effect = Exception("Test error")
        mock_prompt.return_value = prompt

        with patch('wo_gen.parse_args', return_value=Namespace(**self.test_args)):
            with self.assertRaises(Exception):
                self.loop.run_until_complete(main())


if __name__ == '__main__':
    unittest.main()
"""
Hierarchy level manager for WO Artifact Generator
"""

import json
import os
import subprocess
from datetime import datetime, timezone
import logging
from typing import List, Optional


class HierarchyManager:
    """Manages hierarchy levels from Azure Edge contexts.

    Levels are cached in ``config/hierarchy_levels.json``; a failed Azure
    query never overwrites the last known-good cache.
    """

    # Previously hard-coded value, kept as the fallback for backward
    # compatibility; override via the AZURE_SUBSCRIPTION_ID env variable.
    DEFAULT_SUBSCRIPTION_ID = "973d15c6-6c57-447e-b9c6-6d79b5b784ab"

    def __init__(self):
        """Initialize cache path / identifiers and ensure the config dir exists."""
        self.hierarchy_file = os.path.join("config", "hierarchy_levels.json")
        # Generalized: environment first, hard-coded default second.
        self.subscription_id = os.environ.get(
            "AZURE_SUBSCRIPTION_ID", self.DEFAULT_SUBSCRIPTION_ID
        )
        self.api_version = "2025-01-01-preview"

        # Create config directory if it doesn't exist
        os.makedirs("config", exist_ok=True)

    def get_hierarchy_levels(self) -> List[str]:
        """
        Get hierarchy levels from the stored cache file.
        Falls back to ['factory', 'line'] if the file is missing or unreadable.

        Returns:
            List of lower-cased hierarchy level names
        """
        try:
            if os.path.exists(self.hierarchy_file):
                with open(self.hierarchy_file, 'r') as f:
                    data = json.load(f)
                levels = data.get('levels', [])
                # The cache stores either plain strings or objects with a
                # 'name' field (the raw Azure response shape).
                if isinstance(levels, list):
                    if levels and isinstance(levels[0], dict):
                        return [
                            level.get('name', '').lower()
                            for level in levels if level.get('name')
                        ]
                    return [str(level).lower() for level in levels]
        except Exception as e:
            logging.warning(f"Failed to read hierarchy levels: {e}")

        return ['factory', 'line']

    def update_hierarchy_levels(self) -> None:
        """
        Update hierarchy levels by querying Azure Edge contexts and store
        them in ``self.hierarchy_file``. On any failure the existing cache
        is deliberately left untouched.
        """
        try:
            url = (
                f"https://management.azure.com/subscriptions/{self.subscription_id}"
                f"/providers/microsoft.edge/contexts?api-version={self.api_version}"
            )
            # List-form argv without shell=True: immune to shell injection
            # now that the subscription id can come from the environment.
            # NOTE(review): on Windows the CLI entry point is 'az.cmd' —
            # confirm resolution there if Windows is a supported platform.
            context_result = subprocess.run(
                ["az", "rest", "--method", "get", "--url", url],
                capture_output=True,
                text=True
            )

            if context_result.returncode != 0:
                raise Exception(f"Context command failed: {context_result.stderr}")

            # Parse context JSON
            context_json = json.loads(context_result.stdout)

            if not context_json.get('value'):
                raise Exception("No contexts found in response")

            # Hierarchies are taken from the first context only.
            hierarchies = context_json['value'][0]['properties'].get('hierarchies', [])

            # Store hierarchies with a timezone-aware timestamp
            # (datetime.utcnow() is deprecated since Python 3.12).
            data = {
                'levels': hierarchies,
                'last_updated': datetime.now(timezone.utc).isoformat()
            }

            with open(self.hierarchy_file, 'w') as f:
                json.dump(data, f, indent=2)

            logging.info(f"Updated hierarchy levels: {hierarchies}")

        except Exception as e:
            logging.error(f"Failed to update hierarchy levels: {e}")
            # Don't update file on error to preserve last good state
"""
Logger utility for WO Artifact Generator
"""

import logging
import sys
from typing import Any, Optional


def setup_logger(verbose: bool = False, name: str = "__main__") -> logging.Logger:
    """
    Set up and configure a logger instance.

    Args:
        verbose: Whether to enable verbose (DEBUG) logging
        name: Name for the logger instance

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger(name)

    # Attach the stdout handler only once per named logger.
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # Prevent duplicate records via the root logger.
        logger.propagate = False

    # Bug fix: apply the level on every call. Previously it was only set on
    # the first call for a given name, so a later call with a different
    # ``verbose`` flag silently kept the stale level. (The old
    # ``logger.handlers.clear()`` inside the empty-handlers branch was dead
    # code and has been removed.)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    return logger


class LoggerMixin:
    """
    Mixin class to add logging capabilities to any class.
    """

    def __init__(self, verbose: bool = False, *args: Any, **kwargs: Any) -> None:
        """
        Initialize logger mixin.

        Args:
            verbose: Whether to enable verbose logging
            *args: Additional positional arguments forwarded to super()
            **kwargs: Additional keyword arguments forwarded to super()
        """
        super().__init__(*args, **kwargs)
        # Each class gets a logger named after itself.
        self.logger = setup_logger(verbose=verbose, name=self.__class__.__name__)
"""
Manages loading and validation of GPT prompts
"""
import os
from typing import Optional
from utils.logger import LoggerMixin


class PromptManager(LoggerMixin):
    """
    Manages loading and validation of analysis prompts.
    Supports both default and custom prompts.
    """

    def __init__(self, prompt_path: Optional[str] = None):
        """
        Initialize prompt manager.

        Args:
            prompt_path: Optional path to custom prompt file
        """
        super().__init__()
        # prompts/ lives three directory levels above this module.
        self.root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        self.prompts_dir = os.path.join(self.root_dir, 'prompts')
        self.default_prompt_path = os.path.join(self.prompts_dir, 'default', 'prompt.txt')
        self.custom_prompt_path = prompt_path

    def get_prompt(self) -> str:
        """
        Return the prompt text, preferring the custom file over the default.

        Returns:
            Prompt content as string
        """
        try:
            custom = self.custom_prompt_path
            if custom:
                # Relative custom paths are resolved against prompts_dir.
                if not os.path.isabs(custom):
                    custom = os.path.join(self.prompts_dir, custom)
                if os.path.exists(custom):
                    self.logger.info(f"Using custom prompt from: {custom}")
                    with open(custom, 'r') as handle:
                        return handle.read()
                self.logger.warning(
                    f"Custom prompt {custom} not found, using default"
                )
            # No custom prompt (or it was missing): use the default.
            return self._get_default_prompt()
        except Exception as e:
            self.logger.error(f"Error loading prompt: {str(e)}")
            return self._get_default_prompt()

    def _get_default_prompt(self) -> str:
        """
        Read and return the bundled default prompt.

        Raises:
            RuntimeError: if the default prompt file cannot be read
        """
        try:
            with open(self.default_prompt_path, 'r') as handle:
                text = handle.read()
        except Exception as e:
            self.logger.error(f"Failed to load default prompt: {str(e)}")
            raise RuntimeError("Could not load any valid prompt")
        self.logger.info("Using default prompt")
        return text
#!/usr/bin/env python3
"""
Workload Orchestration Template Generator
"""
from argparse import ArgumentParser
from typing import Dict, Any, Optional
import json
import os
import asyncio
import yaml
from helm_parser.parser import HelmChartParser
from schema_generator.generator import SchemaGenerator
from template_generator.generator import TemplateGenerator
from utils.prompt_manager import PromptManager
from utils.hierarchy_manager import HierarchyManager
from ai_analyzer.analyzer import AIParameterAnalyzer


def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with chart_path, output_dir, schema_name,
        schema_version, ai_endpoint, ai_key, ai_model, verbose and prompt.
    """
    parser = ArgumentParser(description="Generate WO schema and template from Helm chart")

    parser.add_argument(
        'chart_path',
        help='Path to Helm chart'
    )
    parser.add_argument(
        '--output-dir', '-o',
        default='./output',
        help='Output directory for generated files'
    )
    parser.add_argument(
        '--schema-name',
        required=True,
        help='Name for generated schema'
    )
    parser.add_argument(
        '--schema-version',
        required=True,
        help='Version for generated schema'
    )
    parser.add_argument(
        '--ai-endpoint',
        required=True,
        help='Azure OpenAI endpoint URL'
    )
    parser.add_argument(
        '--ai-key',
        required=True,
        help='Azure OpenAI API key'
    )
    parser.add_argument(
        '--ai-model',
        required=True,
        help='Azure OpenAI model deployment name'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )
    parser.add_argument(
        '--prompt',
        help='Path to custom prompt file relative to prompts directory',
        default=None
    )

    return parser.parse_args()


def ensure_output_dir(path: str):
    """Ensure the output directory exists.

    Bug fix: exist_ok=True replaces the check-then-create pattern, which was
    racy (TOCTOU) if two runs targeted the same directory.
    """
    os.makedirs(path, exist_ok=True)


async def main():
    """Main entry point: parse the chart, then generate schema and template files."""
    args = parse_args()

    # Refresh hierarchy levels from Azure on each run, then read the cache.
    hierarchy_manager = HierarchyManager()
    hierarchy_manager.update_hierarchy_levels()
    hierarchy_levels = hierarchy_manager.get_hierarchy_levels()

    # Load custom prompt if specified (falls back to the default prompt).
    prompt_manager = PromptManager(args.prompt)
    custom_prompt = prompt_manager.get_prompt()

    # Initialize AI analyzer
    ai_analyzer = AIParameterAnalyzer(
        endpoint=args.ai_endpoint,
        api_key=args.ai_key,
        deployment=args.ai_model,
        custom_prompt=custom_prompt,
        hierarchy_levels=hierarchy_levels
    )

    # Parse Helm chart
    parser = HelmChartParser(args.chart_path)
    chart_data = parser.parse()

    # Generate schema (awaited — presumably calls out to Azure OpenAI).
    schema_generator = SchemaGenerator(ai_analyzer=ai_analyzer)
    schema = await schema_generator.generate(
        chart_data=chart_data,
        name=args.schema_name,
        version=args.schema_version
    )

    # Create output files
    ensure_output_dir(args.output_dir)

    schema_file = os.path.join(args.output_dir, f"{args.schema_name}-schema.yaml")
    with open(schema_file, 'w') as f:
        f.write(schema)
    # Bug fix: report success only after the file has actually been written
    # (previously the message was printed before the write, so a failed write
    # still claimed success).
    print(f"Schema saved to {schema_file}")

    # Parse schema from YAML for template generation.
    try:
        schema_dict = yaml.safe_load(schema)
    except Exception as e:
        print(f"Warning: Failed to parse schema as YAML, template generation may be incomplete: {e}")
        schema_dict = None

    # Generate template
    template_generator = TemplateGenerator()
    template = template_generator.generate(
        chart_data=chart_data,
        name=args.schema_name,
        version=args.schema_version,
        schema=schema_dict
    )

    # Save template (same print-after-write fix as the schema above).
    template_file = os.path.join(args.output_dir, f"{args.schema_name}-template.yaml")
    with open(template_file, 'w') as f:
        f.write(template)
    print(f"Template saved to {template_file}")


if __name__ == '__main__':
    asyncio.run(main())
#!/usr/bin/env python3
"""
Combined Workload Orchestration Template Generator
Simplified version that combines parsing, schema and template generation
"""
import os
import yaml
from dataclasses import dataclass
from typing import Dict, Any, Optional, List


@dataclass
class ChartParameter:
    """A single parameter extracted from a Helm chart's values.yaml."""
    name: str
    type: str
    default_value: Optional[Any] = None
    description: Optional[str] = None
    required: bool = False
    nested_path: Optional[List[str]] = None


@dataclass
class ChartData:
    """Parsed Helm chart data: metadata plus flattened parameters."""
    name: str
    version: str
    description: Optional[str]
    parameters: Dict[str, ChartParameter]
    dependencies: List[Dict[str, Any]]


class HelmChartParser:
    """Simplified Helm chart parser (reads Chart.yaml and values.yaml only)."""

    def __init__(self, chart_path: str):
        self.chart_path = chart_path
        self.values_file = os.path.join(chart_path, 'values.yaml')
        self.chart_file = os.path.join(chart_path, 'Chart.yaml')

    def _read_yaml(self, file_path: str) -> Dict[str, Any]:
        """Read and parse a YAML file; missing files yield an empty dict."""
        try:
            if not os.path.exists(file_path):
                print(f"Warning: File not found: {file_path}")
                return {}
            with open(file_path, 'r') as f:
                return yaml.safe_load(f) or {}
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file {file_path}: {str(e)}")

    def _extract_parameters(self, data: Dict[str, Any], path: List[str] = None) -> Dict[str, ChartParameter]:
        """Recursively flatten values.yaml into dotted-path parameters."""
        if path is None:
            path = []

        parameters = {}
        for key, value in data.items():
            current_path = path + [key]

            if isinstance(value, dict):
                # Recurse into mappings; only leaf values become parameters.
                parameters.update(self._extract_parameters(value, current_path))
            else:
                param_name = '.'.join(current_path)
                parameters[param_name] = ChartParameter(
                    name=param_name,
                    type=self._infer_type(value),
                    default_value=value,
                    nested_path=current_path,
                    required=True  # Simplified: treat all as required
                )
        return parameters

    def _infer_type(self, value: Any) -> str:
        """Infer the WO schema type for a value.

        Note the bool check precedes int (bool is a subclass of int); any
        unrecognized value — including None — falls back to 'string'.
        """
        if isinstance(value, bool):
            return 'boolean'
        elif isinstance(value, int):
            return 'int'
        elif isinstance(value, float):
            return 'float'
        elif isinstance(value, list):
            # Simplified: lists are always typed as string arrays.
            return 'array[string]'
        else:
            return 'string'

    def parse(self) -> ChartData:
        """Parse the Helm chart, raising ValueError on a missing Chart.yaml."""
        chart_info = self._read_yaml(self.chart_file)
        if not chart_info:
            raise ValueError(f"Invalid or missing Chart.yaml in {self.chart_path}")

        values = self._read_yaml(self.values_file)
        parameters = self._extract_parameters(values)

        return ChartData(
            name=chart_info.get('name', ''),
            version=chart_info.get('version', ''),
            description=chart_info.get('description'),
            parameters=parameters,
            dependencies=chart_info.get('dependencies', [])
        )


class SchemaGenerator:
    """Simplified schema generator."""

    def generate(self, chart_data: ChartData, name: str, version: str) -> str:
        """Generate a WO schema (YAML string) from parsed chart data."""
        schema = {
            'name': name,
            'version': version,
            'rules': {
                'configs': {}
            }
        }

        for param_name, param in chart_data.parameters.items():
            schema['rules']['configs'][param_name] = {
                'type': param.type,
                'required': param.required,
                'editableAt': ['target'],  # Simplified: always editable at target
                'editableBy': ['admin']    # Simplified: always editable by admin
            }

        if chart_data.description:
            schema['description'] = chart_data.description

        return yaml.dump(schema, sort_keys=False, allow_unicode=True)


class TemplateGenerator:
    """Simplified template generator."""

    def generate(self, chart_data: ChartData, name: str, version: str, schema: Dict[str, Any] = None) -> str:
        """Generate a solution template (YAML string).

        NOTE(review): the ``schema`` argument is currently unused; it is kept
        for interface compatibility with the full generator — confirm before
        removing.
        """
        template = {
            'schema': {
                'name': name,
                'version': version
            }
        }

        # Build the configs section, re-nesting dotted parameter paths.
        configs = {}
        for param_name, param in chart_data.parameters.items():
            if param.nested_path:
                self._set_nested_value(configs, param.nested_path,
                                       self._generate_template_value(param_name))
            else:
                configs[param_name] = self._generate_template_value(param_name)

        template['configs'] = configs

        if chart_data.dependencies:
            # Simplified: a single placeholder dependency entry.
            template['dependencies'] = [{
                'solutionTemplateId': '/common/1.0.0',
                'configsToBeInjected': [],
                'solutionTemplateVersion': '2.x.x'
            }]

        return yaml.dump(template, sort_keys=False, allow_unicode=True)

    def _generate_template_value(self, param_name: str) -> str:
        """Return the WO placeholder expression for a parameter."""
        return f"${{{{$val({param_name})}}}}"

    def _set_nested_value(self, config_dict: Dict[str, Any], path: List[str], value: Any) -> None:
        """Set a value at a nested path, creating intermediate dicts."""
        current = config_dict
        for component in path[:-1]:
            if component not in current:
                current[component] = {}
            current = current[component]
        if path:
            current[path[-1]] = value


def main():
    """Main entry point."""
    import argparse

    # Bug fix: the argparse parser was previously named 'parser' and then
    # silently shadowed by the HelmChartParser below; distinct names now.
    arg_parser = argparse.ArgumentParser(description="Generate WO schema and template from Helm chart")
    arg_parser.add_argument('chart_path', help='Path to Helm chart')
    arg_parser.add_argument('--output-dir', '-o', default='./output', help='Output directory')
    arg_parser.add_argument('--schema-name', required=True, help='Schema name')
    arg_parser.add_argument('--schema-version', required=True, help='Schema version')

    args = arg_parser.parse_args()

    # Ensure output directory exists (exist_ok avoids a TOCTOU race).
    os.makedirs(args.output_dir, exist_ok=True)

    # Parse Helm chart
    chart_parser = HelmChartParser(args.chart_path)
    chart_data = chart_parser.parse()

    # Generate schema
    schema_generator = SchemaGenerator()
    schema = schema_generator.generate(
        chart_data=chart_data,
        name=args.schema_name,
        version=args.schema_version
    )

    # Save schema
    schema_file = os.path.join(args.output_dir, f"{args.schema_name}-schema.yaml")
    print(f"Saving schema to {schema_file}")
    with open(schema_file, 'w') as f:
        f.write(schema)

    # Parse schema for template generation
    schema_dict = yaml.safe_load(schema)

    # Generate template
    template_generator = TemplateGenerator()
    template = template_generator.generate(
        chart_data=chart_data,
        name=args.schema_name,
        version=args.schema_version,
        schema=schema_dict
    )

    # Save template
    template_file = os.path.join(args.output_dir, f"{args.schema_name}-template.yaml")
    print(f"Saving template to {template_file}")
    with open(template_file, 'w') as f:
        f.write(template)


if __name__ == '__main__':
    main()