From 9d9df4fc1edc73a8fbef2e61059a742f96581b74 Mon Sep 17 00:00:00 2001 From: "Prekshith D J (Persistent Systems Inc)" Date: Fri, 17 Oct 2025 14:31:04 +0530 Subject: [PATCH 1/9] updated the network module --- infra/main.bicep | 342 +------------------------- infra/modules/virtualNetwork.bicep | 374 +++++++++++++++++++++++++++++ 2 files changed, 386 insertions(+), 330 deletions(-) create mode 100644 infra/modules/virtualNetwork.bicep diff --git a/infra/main.bicep b/infra/main.bicep index 08a3a109a..0c2451531 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -381,338 +381,20 @@ module userAssignedIdentity 'br/public:avm/res/managed-identity/user-assigned-id enableTelemetry: enableTelemetry } } - -// ========== Network Security Groups ========== // -// WAF best practices for virtual networks: https://learn.microsoft.com/en-us/azure/well-architected/service-guides/virtual-network -// WAF recommendations for networking and connectivity: https://learn.microsoft.com/en-us/azure/well-architected/security/networking -var networkSecurityGroupBackendResourceName = 'nsg-${solutionSuffix}-backend' -module networkSecurityGroupBackend 'br/public:avm/res/network/network-security-group:0.5.1' = if (enablePrivateNetworking) { - name: take('avm.res.network.network-security-group.backend.${networkSecurityGroupBackendResourceName}', 64) - params: { - name: networkSecurityGroupBackendResourceName - location: location - tags: tags - enableTelemetry: enableTelemetry - diagnosticSettings: enableMonitoring ? [{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null - securityRules: [ - { - name: 'deny-hop-outbound' - properties: { - access: 'Deny' - destinationAddressPrefix: '*' - destinationPortRanges: [ - '22' - '3389' - ] - direction: 'Outbound' - priority: 200 - protocol: 'Tcp' - sourceAddressPrefix: 'VirtualNetwork' - sourcePortRange: '*' - } - } - ] - } -} - -var networkSecurityGroupBastionResourceName = 'nsg-${solutionSuffix}-bastion' -module networkSecurityGroupBastion 'br/public:avm/res/network/network-security-group:0.5.1' = if (enablePrivateNetworking) { - name: take('avm.res.network.network-security-group.bastion${networkSecurityGroupBastionResourceName}', 64) - params: { - name: networkSecurityGroupBastionResourceName - location: location - tags: tags - enableTelemetry: enableTelemetry - diagnosticSettings: enableMonitoring ? 
[{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null - securityRules: [ - { - name: 'AllowHttpsInBound' - properties: { - protocol: 'Tcp' - sourcePortRange: '*' - sourceAddressPrefix: 'Internet' - destinationPortRange: '443' - destinationAddressPrefix: '*' - access: 'Allow' - priority: 100 - direction: 'Inbound' - } - } - { - name: 'AllowGatewayManagerInBound' - properties: { - protocol: 'Tcp' - sourcePortRange: '*' - sourceAddressPrefix: 'GatewayManager' - destinationPortRange: '443' - destinationAddressPrefix: '*' - access: 'Allow' - priority: 110 - direction: 'Inbound' - } - } - { - name: 'AllowLoadBalancerInBound' - properties: { - protocol: 'Tcp' - sourcePortRange: '*' - sourceAddressPrefix: 'AzureLoadBalancer' - destinationPortRange: '443' - destinationAddressPrefix: '*' - access: 'Allow' - priority: 120 - direction: 'Inbound' - } - } - { - name: 'AllowBastionHostCommunicationInBound' - properties: { - protocol: '*' - sourcePortRange: '*' - sourceAddressPrefix: 'VirtualNetwork' - destinationPortRanges: [ - '8080' - '5701' - ] - destinationAddressPrefix: 'VirtualNetwork' - access: 'Allow' - priority: 130 - direction: 'Inbound' - } - } - { - name: 'DenyAllInBound' - properties: { - protocol: '*' - sourcePortRange: '*' - sourceAddressPrefix: '*' - destinationPortRange: '*' - destinationAddressPrefix: '*' - access: 'Deny' - priority: 1000 - direction: 'Inbound' - } - } - { - name: 'AllowSshRdpOutBound' - properties: { - protocol: 'Tcp' - sourcePortRange: '*' - sourceAddressPrefix: '*' - destinationPortRanges: [ - '22' - '3389' - ] - destinationAddressPrefix: 'VirtualNetwork' - access: 'Allow' - priority: 100 - direction: 'Outbound' - } - } - { - name: 'AllowAzureCloudCommunicationOutBound' - properties: { - protocol: 'Tcp' - sourcePortRange: '*' - sourceAddressPrefix: '*' - destinationPortRange: '443' - destinationAddressPrefix: 'AzureCloud' - access: 'Allow' - priority: 110 - direction: 'Outbound' - } - } - { - name: 'AllowBastionHostCommunicationOutBound' - properties: { - protocol: '*' - sourcePortRange: '*' - sourceAddressPrefix: 'VirtualNetwork' - destinationPortRanges: [ - '8080' - '5701' - ] - destinationAddressPrefix: 'VirtualNetwork' - access: 'Allow' - priority: 120 - direction: 'Outbound' - } - } - { - name: 'AllowGetSessionInformationOutBound' - properties: { - protocol: '*' - sourcePortRange: '*' - sourceAddressPrefix: '*' - destinationAddressPrefix: 'Internet' - destinationPortRanges: [ - '80' - '443' - ] - access: 'Allow' - priority: 130 - direction: 'Outbound' - } - } - { - name: 'DenyAllOutBound' - properties: { - protocol: '*' - sourcePortRange: '*' - destinationPortRange: '*' - sourceAddressPrefix: '*' - destinationAddressPrefix: '*' - access: 'Deny' - priority: 1000 - direction: 'Outbound' - } - } - ] - } -} - -var networkSecurityGroupAdministrationResourceName = 'nsg-${solutionSuffix}-administration' -module networkSecurityGroupAdministration 'br/public:avm/res/network/network-security-group:0.5.1' = if (enablePrivateNetworking) { - name: take('avm.res.network.network-security-group.administration.${networkSecurityGroupAdministrationResourceName}', 64) - params: { - name: networkSecurityGroupAdministrationResourceName - location: location - tags: tags - enableTelemetry: enableTelemetry - diagnosticSettings: enableMonitoring ? 
[{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null - securityRules: [ - { - name: 'deny-hop-outbound' - properties: { - access: 'Deny' - destinationAddressPrefix: '*' - destinationPortRanges: [ - '22' - '3389' - ] - direction: 'Outbound' - priority: 200 - protocol: 'Tcp' - sourceAddressPrefix: 'VirtualNetwork' - sourcePortRange: '*' - } - } - ] - } -} - -var networkSecurityGroupContainersResourceName = 'nsg-${solutionSuffix}-containers' -module networkSecurityGroupContainers 'br/public:avm/res/network/network-security-group:0.5.1' = if (enablePrivateNetworking) { - name: take('avm.res.network.network-security-group.containers.${networkSecurityGroupContainersResourceName}', 64) - params: { - name: networkSecurityGroupContainersResourceName - location: location - tags: tags - enableTelemetry: enableTelemetry - diagnosticSettings: enableMonitoring ? [{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null - securityRules: [ - { - name: 'deny-hop-outbound' - properties: { - access: 'Deny' - destinationAddressPrefix: '*' - destinationPortRanges: [ - '22' - '3389' - ] - direction: 'Outbound' - priority: 200 - protocol: 'Tcp' - sourceAddressPrefix: 'VirtualNetwork' - sourcePortRange: '*' - } - } - ] - } -} - -var networkSecurityGroupWebsiteResourceName = 'nsg-${solutionSuffix}-website' -module networkSecurityGroupWebsite 'br/public:avm/res/network/network-security-group:0.5.1' = if (enablePrivateNetworking) { - name: take('avm.res.network.network-security-group.website.${networkSecurityGroupWebsiteResourceName}', 64) - params: { - name: networkSecurityGroupWebsiteResourceName - location: location - tags: tags - enableTelemetry: enableTelemetry - diagnosticSettings: enableMonitoring ? [{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null - securityRules: [ - { - name: 'deny-hop-outbound' - properties: { - access: 'Deny' - destinationAddressPrefix: '*' - destinationPortRanges: [ - '22' - '3389' - ] - direction: 'Outbound' - priority: 200 - protocol: 'Tcp' - sourceAddressPrefix: 'VirtualNetwork' - sourcePortRange: '*' - } - } - ] - } -} - // ========== Virtual Network ========== // // WAF best practices for virtual networks: https://learn.microsoft.com/en-us/azure/well-architected/service-guides/virtual-network // WAF recommendations for networking and connectivity: https://learn.microsoft.com/en-us/azure/well-architected/security/networking var virtualNetworkResourceName = 'vnet-${solutionSuffix}' -module virtualNetwork 'br/public:avm/res/network/virtual-network:0.7.0' = if (enablePrivateNetworking) { - name: take('avm.res.network.virtual-network.${virtualNetworkResourceName}', 64) +module virtualNetwork 'modules/virtualNetwork.bicep' = if (enablePrivateNetworking) { + name: take('module.virtualNetwork.${solutionSuffix}', 64) params: { - name: virtualNetworkResourceName + name: 'vnet-${solutionSuffix}' location: location tags: tags enableTelemetry: enableTelemetry addressPrefixes: ['10.0.0.0/8'] - subnets: [ - { - name: 'backend' - addressPrefix: '10.0.0.0/27' - networkSecurityGroupResourceId: networkSecurityGroupBackend!.outputs.resourceId - } - { - name: 'administration' - addressPrefix: '10.0.0.32/27' - networkSecurityGroupResourceId: networkSecurityGroupAdministration!.outputs.resourceId - //natGatewayResourceId: natGateway.outputs.resourceId - } - { - // For Azure Bastion resources deployed on or after November 2, 2021, the minimum AzureBastionSubnet size is /26 or larger (/25, /24, etc.). 
- // https://learn.microsoft.com/en-us/azure/bastion/configuration-settings#subnet - name: 'AzureBastionSubnet' //This exact name is required for Azure Bastion - addressPrefix: '10.0.0.64/26' - networkSecurityGroupResourceId: networkSecurityGroupBastion!.outputs.resourceId - } - { - // If you use your own vnw, you need to provide a subnet that is dedicated exclusively to the Container App environment you deploy. This subnet isn't available to other services - // https://learn.microsoft.com/en-us/azure/container-apps/networking?tabs=workload-profiles-env%2Cazure-cli#custom-vnw-configuration - name: 'containers' - addressPrefix: '10.0.2.0/23' //subnet of size /23 is required for container app - delegation: 'Microsoft.App/environments' - networkSecurityGroupResourceId: networkSecurityGroupContainers!.outputs.resourceId - privateEndpointNetworkPolicies: 'Enabled' - privateLinkServiceNetworkPolicies: 'Enabled' - } - { - // If you use your own vnw, you need to provide a subnet that is dedicated exclusively to the App Environment you deploy. This subnet isn't available to other services - // https://learn.microsoft.com/en-us/azure/app-service/overview-vnet-integration#subnet-requirements - name: 'webserverfarm' - addressPrefix: '10.0.4.0/27' //When you're creating subnets in Azure portal as part of integrating with the virtual network, a minimum size of /27 is required - delegation: 'Microsoft.Web/serverfarms' - networkSecurityGroupResourceId: networkSecurityGroupWebsite!.outputs.resourceId - privateEndpointNetworkPolicies: 'Enabled' - privateLinkServiceNetworkPolicies: 'Enabled' - } - ] + logAnalyticsWorkspaceId: logAnalyticsWorkspaceResourceId + resourceSuffix: solutionSuffix } } @@ -961,7 +643,7 @@ module virtualMachine 'br/public:avm/res/compute/virtual-machine:0.17.0' = if (e ipConfigurations: [ { name: '${virtualMachineResourceName}-nic01-ipconfig01' - subnetResourceId: virtualNetwork!.outputs.subnetResourceIds[1] + subnetResourceId: virtualNetwork!.outputs.administrationSubnetResourceId diagnosticSettings: enableMonitoring //WAF aligned configuration for Monitoring ? [{ workspaceResourceId: logAnalyticsWorkspaceResourceId }] : null @@ -1280,7 +962,7 @@ module aiFoundryAiServices 'br:mcr.microsoft.com/bicep/avm/res/cognitive-service { name: 'pep-${aiFoundryAiServicesResourceName}' customNetworkInterfaceName: 'nic-${aiFoundryAiServicesResourceName}' - subnetResourceId: virtualNetwork!.outputs.subnetResourceIds[0] + subnetResourceId: virtualNetwork!.outputs.backendSubnetResourceId privateDnsZoneGroup: { privateDnsZoneGroupConfigs: [ { @@ -1393,7 +1075,7 @@ module cosmosDb 'br/public:avm/res/document-db/database-account:0.15.0' = { ] } service: 'Sql' - subnetResourceId: virtualNetwork!.outputs.subnetResourceIds[0] + subnetResourceId: virtualNetwork!.outputs.backendSubnetResourceId } ] : [] @@ -1438,7 +1120,7 @@ module containerAppEnvironment 'br/public:avm/res/app/managed-environment:0.11.2 // WAF aligned configuration for Private Networking publicNetworkAccess: 'Enabled' // Always enabling the publicNetworkAccess for Container App Environment internal: false // Must be false when publicNetworkAccess is'Enabled' - infrastructureSubnetResourceId: enablePrivateNetworking ? virtualNetwork.?outputs.?subnetResourceIds[3] : null + infrastructureSubnetResourceId: enablePrivateNetworking ? virtualNetwork.?outputs.?containerSubnetResourceId : null // WAF aligned configuration for Monitoring appLogsConfiguration: enableMonitoring ? 
{ @@ -1826,7 +1508,7 @@ module webSite 'modules/web-sites.bicep' = { // WAF aligned configuration for Private Networking vnetRouteAllEnabled: enablePrivateNetworking ? true : false vnetImagePullEnabled: enablePrivateNetworking ? true : false - virtualNetworkSubnetId: enablePrivateNetworking ? virtualNetwork!.outputs.subnetResourceIds[4] : null + virtualNetworkSubnetId: enablePrivateNetworking ? virtualNetwork!.outputs.webserverfarmSubnetResourceId : null publicNetworkAccess: 'Enabled' // Always enabling the public network access for Web App e2eEncryptionEnabled: true } @@ -1884,7 +1566,7 @@ module avmStorageAccount 'br/public:avm/res/storage/storage-account:0.20.0' = { } ] } - subnetResourceId: virtualNetwork!.outputs.subnetResourceIds[0] + subnetResourceId: virtualNetwork!.outputs.backendSubnetResourceId service: 'blob' } ] @@ -2035,7 +1717,7 @@ module keyvault 'br/public:avm/res/key-vault/vault:0.12.1' = { privateDnsZoneGroupConfigs: [{ privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.keyVault]!.outputs.resourceId }] } service: 'vault' - subnetResourceId: virtualNetwork!.outputs.subnetResourceIds[0] + subnetResourceId: virtualNetwork!.outputs.backendSubnetResourceId } ] : [] diff --git a/infra/modules/virtualNetwork.bicep b/infra/modules/virtualNetwork.bicep new file mode 100644 index 000000000..b9b5f11b6 --- /dev/null +++ b/infra/modules/virtualNetwork.bicep @@ -0,0 +1,374 @@ +/****************************************************************************************************************************/ +// Networking - NSGs, VNET and Subnets. Each subnet has its own NSG +/****************************************************************************************************************************/ +@description('Name of the virtual network.') +param name string + +@description('Azure region to deploy resources.') +param location string = resourceGroup().location + +@description('Required. An Array of 1 or more IP Address Prefixes for the Virtual Network.') +param addressPrefixes array + +@description('An array of subnets to be created within the virtual network. 
Each subnet can have its own configuration and associated Network Security Group (NSG).') +param subnets subnetType[] = [ + + + { + name:'backend' + addressPrefixes: ['10.0.0.0/27'] + networkSecurityGroup: { + name: 'nsg-backend' + securityRules: [ + { + name: 'deny-hop-outbound' + properties: { + access: 'Deny' + destinationAddressPrefix: '*' + destinationPortRanges: [ + '22' + '3389' + ] + direction: 'Outbound' + priority: 200 + protocol: 'Tcp' + sourceAddressPrefix: 'VirtualNetwork' + sourcePortRange: '*' + } + } + ] + } + } + { + name: 'containers' + addressPrefixes: ['10.0.2.0/23'] + delegation: 'Microsoft.App/environments' + privateEndpointNetworkPolicies: 'Enabled' + privateLinkServiceNetworkPolicies: 'Enabled' + networkSecurityGroup: { + name: 'nsg-containers' + securityRules: [ + { + name: 'deny-hop-outbound' + properties: { + access: 'Deny' + destinationAddressPrefix: '*' + destinationPortRanges: [ + '22' + '3389' + ] + direction: 'Outbound' + priority: 200 + protocol: 'Tcp' + sourceAddressPrefix: 'VirtualNetwork' + sourcePortRange: '*' + } + } + ] + } + } + { + name: 'webserverfarm' + addressPrefixes: ['10.0.4.0/27'] + delegation: 'Microsoft.Web/serverfarms' + privateEndpointNetworkPolicies: 'Enabled' + privateLinkServiceNetworkPolicies: 'Enabled' + networkSecurityGroup: { + name: 'nsg-webserverfarm' + securityRules: [ + { + name: 'deny-hop-outbound' + properties: { + access: 'Deny' + destinationAddressPrefix: '*' + destinationPortRanges: [ + '22' + '3389' + ] + direction: 'Outbound' + priority: 200 + protocol: 'Tcp' + sourceAddressPrefix: 'VirtualNetwork' + sourcePortRange: '*' + } + } + ] + } + } + { + name: 'administration' + addressPrefixes: ['10.0.0.32/27'] + networkSecurityGroup: { + name: 'nsg-administration' + securityRules: [ + { + name: 'deny-hop-outbound' + properties: { + access: 'Deny' + destinationAddressPrefix: '*' + destinationPortRanges: [ + '22' + '3389' + ] + direction: 'Outbound' + priority: 200 + protocol: 'Tcp' + sourceAddressPrefix: 'VirtualNetwork' + sourcePortRange: '*' + } + } + ] + } + } + { + name: 'AzureBastionSubnet' // Required name for Azure Bastion + addressPrefixes: ['10.0.0.64/26'] + networkSecurityGroup: { + name: 'nsg-bastion' + securityRules: [ + { + name: 'AllowGatewayManager' + properties: { + access: 'Allow' + direction: 'Inbound' + priority: 2702 + protocol: '*' + sourcePortRange: '*' + destinationPortRange: '443' + sourceAddressPrefix: 'GatewayManager' + destinationAddressPrefix: '*' + } + } + { + name: 'AllowHttpsInBound' + properties: { + access: 'Allow' + direction: 'Inbound' + priority: 2703 + protocol: '*' + sourcePortRange: '*' + destinationPortRange: '443' + sourceAddressPrefix: 'Internet' + destinationAddressPrefix: '*' + } + } + { + name: 'AllowSshRdpOutbound' + properties: { + access: 'Allow' + direction: 'Outbound' + priority: 100 + protocol: '*' + sourcePortRange: '*' + destinationPortRanges: ['22', '3389'] + sourceAddressPrefix: '*' + destinationAddressPrefix: 'VirtualNetwork' + } + } + { + name: 'AllowAzureCloudOutbound' + properties: { + access: 'Allow' + direction: 'Outbound' + priority: 110 + protocol: 'Tcp' + sourcePortRange: '*' + destinationPortRange: '443' + sourceAddressPrefix: '*' + destinationAddressPrefix: 'AzureCloud' + } + } + ] + } + } +] + +@description('Optional. Tags to be applied to the resources.') +param tags object = {} + +@description('Optional. The resource ID of the Log Analytics Workspace to send diagnostic logs to.') +param logAnalyticsWorkspaceId string + +@description('Optional. 
Enable/Disable usage telemetry for module.')
+param enableTelemetry bool = true
+
+@description('Required. Suffix for resource naming.')
+param resourceSuffix string
+
+// VM Size Notes:
+// 1. B-series VMs (like Standard_B2ms) do not support accelerated networking.
+// 2. Pick a VM size that does support accelerated networking (the usual jump-box candidates):
+//    Standard_DS2_v2 (2 vCPU, 7 GiB RAM, Premium SSD) // The most broadly available (it’s a legacy SKU supported in virtually every region).
+//    Standard_D2s_v3 (2 vCPU, 8 GiB RAM, Premium SSD) // Next most common.
+//    Standard_D2s_v4 (2 vCPU, 8 GiB RAM, Premium SSD) // Newest, so available in fewer regions.
+
+// Subnet Classless Inter-Domain Routing (CIDR) Sizing Reference Table (Best Practices)
+// | CIDR | # of Addresses | # of /24s | Notes                                  |
+// |------|----------------|-----------|----------------------------------------|
+// | /24  | 256            | 1         | Smallest recommended for Azure subnets |
+// | /23  | 512            | 2         | Good for 1-2 workloads per subnet      |
+// | /22  | 1024           | 4         | Good for 2-4 workloads per subnet      |
+// | /21  | 2048           | 8         |                                        |
+// | /20  | 4096           | 16        | Used for default VNet in this solution |
+// | /19  | 8192           | 32        |                                        |
+// | /18  | 16384          | 64        |                                        |
+// | /17  | 32768          | 128       |                                        |
+// | /16  | 65536          | 256       |                                        |
+// | /15  | 131072         | 512       |                                        |
+// | /14  | 262144         | 1024      |                                        |
+// | /13  | 524288         | 2048      |                                        |
+// | /12  | 1048576        | 4096      |                                        |
+// | /11  | 2097152        | 8192      |                                        |
+// | /10  | 4194304        | 16384     |                                        |
+// | /9   | 8388608        | 32768     |                                        |
+// | /8   | 16777216       | 65536     |                                        |
+//
+// Best Practice Notes:
+// - Use /24 as the minimum subnet size for Azure (smaller subnets are not supported for most services).
+// - Plan for future growth: allocate larger address spaces (e.g., /20 or /21 for VNets) to allow for new subnets.
+// - Avoid overlapping address spaces with on-premises or other VNets.
+// - Use contiguous, non-overlapping ranges for subnets.
+// - Document subnet usage and purpose in code comments.
+// - For AVM modules, ensure only one delegation per subnet and leave delegations empty if not required.
+
+// 1. Create NSGs for subnets
+// using AVM Network Security Group module
+// https://github.com/Azure/bicep-registry-modules/tree/main/avm/res/network/network-security-group
+
+@batchSize(1)
+module nsgs 'br/public:avm/res/network/network-security-group:0.5.1' = [
+  for (subnet, i) in subnets: if (!empty(subnet.?networkSecurityGroup)) {
+    name: take('avm.res.network.network-security-group.${subnet.?networkSecurityGroup.name}.${resourceSuffix}', 64)
+    params: {
+      name: '${subnet.?networkSecurityGroup.name}-${resourceSuffix}'
+      location: location
+      securityRules: subnet.?networkSecurityGroup.securityRules
+      tags: tags
+      enableTelemetry: enableTelemetry
+    }
+  }
+]
+
+// 2. Create VNet and subnets, with subnets associated with corresponding NSGs
+// using AVM Virtual Network module
+// https://github.com/Azure/bicep-registry-modules/tree/main/avm/res/network/virtual-network
+
+module virtualNetwork 'br/public:avm/res/network/virtual-network:0.7.0' = {
+  name: take('avm.res.network.virtual-network.${name}', 64)
+  params: {
+    name: name
+    location: location
+    addressPrefixes: addressPrefixes
+    subnets: [
+      for (subnet, i) in subnets: {
+        name: subnet.name
+        addressPrefixes: subnet.?addressPrefixes
+        networkSecurityGroupResourceId: !empty(subnet.?networkSecurityGroup) ? nsgs[i]!.outputs.resourceId : null
+        privateEndpointNetworkPolicies: subnet.?privateEndpointNetworkPolicies
+        privateLinkServiceNetworkPolicies: subnet.?privateLinkServiceNetworkPolicies
+        delegation: subnet.?delegation
+      }
+    ]
+    diagnosticSettings: [
+      {
+        name: 'vnetDiagnostics'
+        workspaceResourceId: logAnalyticsWorkspaceId
+        logCategoriesAndGroups: [
+          {
+            categoryGroup: 'allLogs'
+            enabled: true
+          }
+        ]
+        metricCategories: [
+          {
+            category: 'AllMetrics'
+            enabled: true
+          }
+        ]
+      }
+    ]
+    tags: tags
+    enableTelemetry: enableTelemetry
+  }
+}
+
+output name string = virtualNetwork.outputs.name
+output resourceId string = virtualNetwork.outputs.resourceId
+
+// Combined output array that holds subnet details along with NSG information
+output subnets subnetOutputType[] = [
+  for (subnet, i) in subnets: {
+    name: subnet.name
+    resourceId: virtualNetwork.outputs.subnetResourceIds[i]
+    nsgName: !empty(subnet.?networkSecurityGroup) ? subnet.?networkSecurityGroup.name : null
+    nsgResourceId: !empty(subnet.?networkSecurityGroup) ? nsgs[i]!.outputs.resourceId : null
+  }
+]
+
+// Dynamic outputs for individual subnets for backward compatibility
+output backendSubnetResourceId string = contains(map(subnets, subnet => subnet.name), 'backend') ? virtualNetwork.outputs.subnetResourceIds[indexOf(map(subnets, subnet => subnet.name), 'backend')] : ''
+output containerSubnetResourceId string = contains(map(subnets, subnet => subnet.name), 'containers') ? virtualNetwork.outputs.subnetResourceIds[indexOf(map(subnets, subnet => subnet.name), 'containers')] : ''
+output administrationSubnetResourceId string = contains(map(subnets, subnet => subnet.name), 'administration') ? virtualNetwork.outputs.subnetResourceIds[indexOf(map(subnets, subnet => subnet.name), 'administration')] : ''
+output webserverfarmSubnetResourceId string = contains(map(subnets, subnet => subnet.name), 'webserverfarm') ? virtualNetwork.outputs.subnetResourceIds[indexOf(map(subnets, subnet => subnet.name), 'webserverfarm')] : ''
+output bastionSubnetResourceId string = contains(map(subnets, subnet => subnet.name), 'AzureBastionSubnet') ? virtualNetwork.outputs.subnetResourceIds[indexOf(map(subnets, subnet => subnet.name), 'AzureBastionSubnet')] : ''
+
+@export()
+@description('Custom type definition for subnet resource information as output')
+type subnetOutputType = {
+  @description('The name of the subnet.')
+  name: string
+
+  @description('The resource ID of the subnet.')
+  resourceId: string
+
+  @description('The name of the associated network security group, if any.')
+  nsgName: string?
+
+  @description('The resource ID of the associated network security group, if any.')
+  nsgResourceId: string?
+}
+
+@export()
+@description('Custom type definition for subnet configuration')
+type subnetType = {
+  @description('Required. The name of the subnet resource.')
+  name: string
+
+  @description('Required. Prefixes for the subnet.') // Required to ensure at least one prefix is provided
+  addressPrefixes: string[]
+
+  @description('Optional. The delegation to enable on the subnet.')
+  delegation: string?
+
+  @description('Optional. Enable or disable applying network policies on private endpoints in the subnet.')
+  privateEndpointNetworkPolicies: ('Disabled' | 'Enabled' | 'NetworkSecurityGroupEnabled' | 'RouteTableEnabled')?
+
+  @description('Optional. Enable or disable applying network policies on the private link service in the subnet.')
+  privateLinkServiceNetworkPolicies: ('Disabled' | 'Enabled')?
+
+  @description('Optional. 
Network Security Group configuration for the subnet.') + networkSecurityGroup: networkSecurityGroupType? + + @description('Optional. The resource ID of the route table to assign to the subnet.') + routeTableResourceId: string? + + @description('Optional. An array of service endpoint policies.') + serviceEndpointPolicies: object[]? + + @description('Optional. The service endpoints to enable on the subnet.') + serviceEndpoints: string[]? + + @description('Optional. Set this property to false to disable default outbound connectivity for all VMs in the subnet. This property can only be set at the time of subnet creation and cannot be updated for an existing subnet.') + defaultOutboundAccess: bool? +} + +@export() +@description('Custom type definition for network security group configuration') +type networkSecurityGroupType = { + @description('Required. The name of the network security group.') + name: string + + @description('Required. The security rules for the network security group.') + securityRules: object[] +} From 6523d89203c99327bf62692ef4a26c51173755bb Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Fri, 17 Oct 2025 11:35:01 -0400 Subject: [PATCH 2/9] Remove websocket streaming and unused kernel creation Deleted the websocket_streaming.py module and removed the create_kernel method from AppConfig, as the system now uses Azure AI Agent Project pattern. Added placeholder files for messages_af.py and utils_af.py to support future Azure Foundry integration. --- src/backend/common/config/app_config.py | 13 +- src/backend/common/models/messages_af.py | 0 src/backend/common/utils/utils_af.py | 0 .../common/utils/websocket_streaming.py | 214 ------------------ 4 files changed, 1 insertion(+), 226 deletions(-) create mode 100644 src/backend/common/models/messages_af.py create mode 100644 src/backend/common/utils/utils_af.py delete mode 100644 src/backend/common/utils/websocket_streaming.py diff --git a/src/backend/common/config/app_config.py b/src/backend/common/config/app_config.py index 5626752c1..92ece9734 100644 --- a/src/backend/common/config/app_config.py +++ b/src/backend/common/config/app_config.py @@ -7,7 +7,7 @@ from azure.cosmos import CosmosClient from azure.identity import DefaultAzureCredential, ManagedIdentityCredential from dotenv import load_dotenv -from semantic_kernel import Kernel + # Load environment variables from .env file load_dotenv() @@ -215,17 +215,6 @@ def get_cosmos_database_client(self): ) raise - def create_kernel(self): - """Creates a new Semantic Kernel instance. - - Returns: - A new Semantic Kernel instance - """ - # Create a new kernel instance without manually configuring OpenAI services - # The agents will be created using Azure AI Agent Project pattern instead - kernel = Kernel() - return kernel - def get_ai_project_client(self): """Create and return an AIProjectClient for Azure AI Foundry using from_connection_string. 
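# For context, a minimal sketch of the Azure AI Agent Project pattern that now
# replaces create_kernel(); the connection string format, model name, and agent
# parameters below are illustrative assumptions, not values from this repo.
from azure.ai.projects import AIProjectClient
from azure.identity import DefaultAzureCredential

project_client = AIProjectClient.from_connection_string(
    credential=DefaultAzureCredential(),
    conn_str="<region>.api.azureml.ms;<subscription-id>;<resource-group>;<project-name>",
)
# Agents are created against the Foundry project instead of a local Kernel.
agent = project_client.agents.create_agent(
    model="gpt-4o",
    name="orchestrator",
    instructions="You coordinate multi-agent plans.",
)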
diff --git a/src/backend/common/models/messages_af.py b/src/backend/common/models/messages_af.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/common/utils/utils_af.py b/src/backend/common/utils/utils_af.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/common/utils/websocket_streaming.py b/src/backend/common/utils/websocket_streaming.py deleted file mode 100644 index 6a1baf519..000000000 --- a/src/backend/common/utils/websocket_streaming.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -WebSocket endpoint for real-time plan execution streaming -This is a basic implementation that can be expanded based on your backend framework -""" - -import asyncio -import json -import logging -from typing import Dict, Set - -from fastapi import WebSocket, WebSocketDisconnect - -logger = logging.getLogger(__name__) - - -class WebSocketManager: - def __init__(self): - self.active_connections: Dict[str, WebSocket] = {} - self.plan_subscriptions: Dict[str, Set[str]] = {} # plan_id -> set of connection_ids - - async def connect(self, websocket: WebSocket, connection_id: str): - await websocket.accept() - self.active_connections[connection_id] = websocket - logger.info(f"WebSocket connection established: {connection_id}") - - def disconnect(self, connection_id: str): - if connection_id in self.active_connections: - del self.active_connections[connection_id] - - # Remove from all plan subscriptions - for plan_id, subscribers in self.plan_subscriptions.items(): - subscribers.discard(connection_id) - - logger.info(f"WebSocket connection closed: {connection_id}") - - async def send_personal_message(self, message: dict, connection_id: str): - if connection_id in self.active_connections: - websocket = self.active_connections[connection_id] - try: - await websocket.send_text(json.dumps(message)) - except Exception as e: - logger.error(f"Error sending message to {connection_id}: {e}") - self.disconnect(connection_id) - - async def broadcast_to_plan(self, message: dict, plan_id: str): - """Broadcast message to all subscribers of a specific plan""" - if plan_id not in self.plan_subscriptions: - return - - disconnected_connections = [] - - for connection_id in self.plan_subscriptions[plan_id].copy(): - if connection_id in self.active_connections: - websocket = self.active_connections[connection_id] - try: - await websocket.send_text(json.dumps(message)) - except Exception as e: - logger.error(f"Error broadcasting to {connection_id}: {e}") - disconnected_connections.append(connection_id) - - # Clean up failed connections - for connection_id in disconnected_connections: - self.disconnect(connection_id) - - def subscribe_to_plan(self, connection_id: str, plan_id: str): - if plan_id not in self.plan_subscriptions: - self.plan_subscriptions[plan_id] = set() - - self.plan_subscriptions[plan_id].add(connection_id) - logger.info(f"Connection {connection_id} subscribed to plan {plan_id}") - - def unsubscribe_from_plan(self, connection_id: str, plan_id: str): - if plan_id in self.plan_subscriptions: - self.plan_subscriptions[plan_id].discard(connection_id) - logger.info(f"Connection {connection_id} unsubscribed from plan {plan_id}") - - -# Global WebSocket manager instance -ws_manager = WebSocketManager() - - -# WebSocket endpoint -async def websocket_streaming_endpoint(websocket: WebSocket): - connection_id = f"conn_{id(websocket)}" - await ws_manager.connect(websocket, connection_id) - - try: - while True: - data = await websocket.receive_text() - message = json.loads(data) - - 
message_type = message.get("type") - - if message_type == "subscribe_plan": - plan_id = message.get("plan_id") - if plan_id: - ws_manager.subscribe_to_plan(connection_id, plan_id) - - # Send confirmation - await ws_manager.send_personal_message( - {"type": "subscription_confirmed", "plan_id": plan_id}, - connection_id, - ) - - elif message_type == "unsubscribe_plan": - plan_id = message.get("plan_id") - if plan_id: - ws_manager.unsubscribe_from_plan(connection_id, plan_id) - - else: - logger.warning(f"Unknown message type: {message_type}") - - except WebSocketDisconnect: - ws_manager.disconnect(connection_id) - except Exception as e: - logger.error(f"WebSocket error: {e}") - ws_manager.disconnect(connection_id) - - -# Example function to send plan updates (call this from your plan execution logic) -async def send_plan_update( - plan_id: str, - step_id: str = None, - agent_name: str = None, - content: str = None, - status: str = "in_progress", - message_type: str = "action", -): - """ - Send a streaming update for a specific plan - """ - message = { - "type": "plan_update", - "data": { - "plan_id": plan_id, - "step_id": step_id, - "agent_name": agent_name, - "content": content, - "status": status, - "message_type": message_type, - "timestamp": asyncio.get_event_loop().time(), - }, - } - - await ws_manager.broadcast_to_plan(message, plan_id) - - -# Example function to send agent messages -async def send_agent_message( - plan_id: str, agent_name: str, content: str, message_type: str = "thinking" -): - """ - Send a streaming message from an agent - """ - message = { - "type": "agent_message", - "data": { - "plan_id": plan_id, - "agent_name": agent_name, - "content": content, - "message_type": message_type, - "timestamp": asyncio.get_event_loop().time(), - }, - } - - await ws_manager.broadcast_to_plan(message, plan_id) - - -# Example function to send step updates -async def send_step_update( - plan_id: str, step_id: str, status: str, content: str = None -): - """ - Send a streaming update for a specific step - """ - message = { - "type": "step_update", - "data": { - "plan_id": plan_id, - "step_id": step_id, - "status": status, - "content": content, - "timestamp": asyncio.get_event_loop().time(), - }, - } - - await ws_manager.broadcast_to_plan(message, plan_id) - - -# Example integration with FastAPI -""" -from fastapi import FastAPI - -app = FastAPI() - -@app.websocket("/ws/streaming") -async def websocket_endpoint(websocket: WebSocket): - await websocket_streaming_endpoint(websocket) - -# Example usage in your plan execution logic: -async def execute_plan_step(plan_id: str, step_id: str): - # Send initial update - await send_step_update(plan_id, step_id, "in_progress", "Starting step execution...") - # Simulate some work - await asyncio.sleep(2) - # Send agent thinking message - await send_agent_message(plan_id, "Data Analyst", "Analyzing the requirements...", "thinking") - await asyncio.sleep(1) - # Send agent action message - await send_agent_message(plan_id, "Data Analyst", "Processing data and generating insights...", "action") - await asyncio.sleep(3) - # Send completion update - await send_step_update(plan_id, step_id, "completed", "Step completed successfully!") -""" From ea6d3b81b4c31bd35d09e90725d8f6341733d6b9 Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Mon, 20 Oct 2025 14:17:59 -0400 Subject: [PATCH 3/9] Add Agent Framework Pydantic models for migration Introduces new Pydantic-based data models in messages_af.py to replace KernelBaseModel from semantic_kernel. 
All original model names and structures are preserved to support incremental migration from the previous framework. --- src/backend/common/models/messages_af.py | 258 +++++++++++++++++++++++ 1 file changed, 258 insertions(+) diff --git a/src/backend/common/models/messages_af.py b/src/backend/common/models/messages_af.py index e69de29bb..cd46587dc 100644 --- a/src/backend/common/models/messages_af.py +++ b/src/backend/common/models/messages_af.py @@ -0,0 +1,258 @@ +""" +Agent Framework model equivalents for former Semantic Kernel-backed data models. + +This file replaces usage of KernelBaseModel from semantic_kernel with plain Pydantic BaseModel. +All original model names are preserved to enable incremental migration. +""" + +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# Enumerations +# --------------------------------------------------------------------------- + +class DataType(str, Enum): + session = "session" + plan = "plan" + step = "step" + agent_message = "agent_message" + team_config = "team_config" + user_current_team = "user_current_team" + m_plan = "m_plan" + m_plan_message = "m_plan_message" + + +class AgentType(str, Enum): + HUMAN = "Human_Agent" + HR = "Hr_Agent" + MARKETING = "Marketing_Agent" + PROCUREMENT = "Procurement_Agent" + PRODUCT = "Product_Agent" + GENERIC = "Generic_Agent" + TECH_SUPPORT = "Tech_Support_Agent" + GROUP_CHAT_MANAGER = "Group_Chat_Manager" + PLANNER = "Planner_Agent" + # Extend as needed + + +class StepStatus(str, Enum): + planned = "planned" + awaiting_feedback = "awaiting_feedback" + approved = "approved" + rejected = "rejected" + action_requested = "action_requested" + completed = "completed" + failed = "failed" + + +class PlanStatus(str, Enum): + in_progress = "in_progress" + completed = "completed" + failed = "failed" + canceled = "canceled" + approved = "approved" + created = "created" + + +class HumanFeedbackStatus(str, Enum): + requested = "requested" + accepted = "accepted" + rejected = "rejected" + + +class MessageRole(str, Enum): + system = "system" + user = "user" + assistant = "assistant" + function = "function" + + +class AgentMessageType(str, Enum): + # Removed trailing commas to avoid tuple enum values + HUMAN_AGENT = "Human_Agent" + AI_AGENT = "AI_Agent" + + +# --------------------------------------------------------------------------- +# Base Models +# --------------------------------------------------------------------------- + +class BaseDataModel(BaseModel): + """Base data model with common fields.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + session_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + timestamp: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class AgentMessage(BaseDataModel): + """Base class for messages sent between agents.""" + data_type: Literal[DataType.agent_message] = DataType.agent_message + plan_id: str + content: str + source: str + step_id: Optional[str] = None + + +class Session(BaseDataModel): + """Represents a user session.""" + data_type: Literal[DataType.session] = DataType.session + user_id: str + current_status: str + message_to_user: Optional[str] = None + + +class UserCurrentTeam(BaseDataModel): + """Represents the current team of a user.""" + data_type: Literal[DataType.user_current_team] = DataType.user_current_team + 
user_id: str + team_id: str + + +class Plan(BaseDataModel): + """Represents a plan containing multiple steps.""" + data_type: Literal[DataType.plan] = DataType.plan + plan_id: str + user_id: str + initial_goal: str + overall_status: PlanStatus = PlanStatus.in_progress + approved: bool = False + source: str = AgentType.PLANNER.value + m_plan: Optional[Dict[str, Any]] = None + summary: Optional[str] = None + team_id: Optional[str] = None + streaming_message: Optional[str] = None + human_clarification_request: Optional[str] = None + human_clarification_response: Optional[str] = None + + +class Step(BaseDataModel): + """Represents an individual step (task) within a plan.""" + data_type: Literal[DataType.step] = DataType.step + plan_id: str + user_id: str + action: str + agent: AgentType + status: StepStatus = StepStatus.planned + agent_reply: Optional[str] = None + human_feedback: Optional[str] = None + human_approval_status: Optional[HumanFeedbackStatus] = HumanFeedbackStatus.requested + updated_action: Optional[str] = None + + +class TeamSelectionRequest(BaseDataModel): + """Request model for team selection.""" + team_id: str + + +class TeamAgent(BaseModel): + """Represents an agent within a team.""" + input_key: str + type: str + name: str + deployment_name: str + system_message: str = "" + description: str = "" + icon: str + index_name: str = "" + use_rag: bool = False + use_mcp: bool = False + use_bing: bool = False + use_reasoning: bool = False + coding_tools: bool = False + + +class StartingTask(BaseModel): + """Represents a starting task for a team.""" + id: str + name: str + prompt: str + created: str + creator: str + logo: str + + +class TeamConfiguration(BaseDataModel): + """Represents a team configuration stored in the database.""" + team_id: str + data_type: Literal[DataType.team_config] = DataType.team_config + session_id: str # partition key + name: str + status: str + created: str + created_by: str + agents: List[TeamAgent] = Field(default_factory=list) + description: str = "" + logo: str = "" + plan: str = "" + starting_tasks: List[StartingTask] = Field(default_factory=list) + user_id: str # who uploaded this configuration + + +class PlanWithSteps(Plan): + """Plan model that includes the associated steps.""" + steps: List[Step] = Field(default_factory=list) + total_steps: int = 0 + planned: int = 0 + awaiting_feedback: int = 0 + approved: int = 0 + rejected: int = 0 + action_requested: int = 0 + completed: int = 0 + failed: int = 0 + + def update_step_counts(self) -> None: + """Update the counts of steps by their status.""" + status_counts = { + StepStatus.planned: 0, + StepStatus.awaiting_feedback: 0, + StepStatus.approved: 0, + StepStatus.rejected: 0, + StepStatus.action_requested: 0, + StepStatus.completed: 0, + StepStatus.failed: 0, + } + for step in self.steps: + status_counts[step.status] += 1 + + self.total_steps = len(self.steps) + self.planned = status_counts[StepStatus.planned] + self.awaiting_feedback = status_counts[StepStatus.awaiting_feedback] + self.approved = status_counts[StepStatus.approved] + self.rejected = status_counts[StepStatus.rejected] + self.action_requested = status_counts[StepStatus.action_requested] + self.completed = status_counts[StepStatus.completed] + self.failed = status_counts[StepStatus.failed] + + # Mark the plan as complete if the sum of completed and failed steps equals the total number of steps + if self.total_steps > 0 and (self.completed + self.failed) == self.total_steps: + self.overall_status = PlanStatus.completed + + +class 
InputTask(BaseModel):
+    """Message representing the initial input task from the user."""
+    session_id: str
+    description: str
+
+
+class UserLanguage(BaseModel):
+    language: str
+
+
+class AgentMessageData(BaseDataModel):
+    """Represents a multi-plan agent message."""
+    data_type: Literal[DataType.m_plan_message] = DataType.m_plan_message
+    plan_id: str
+    user_id: str
+    agent: str
+    m_plan_id: Optional[str] = None
+    agent_type: AgentMessageType = AgentMessageType.AI_AGENT
+    content: str
+    raw_data: str
+    steps: List[Any] = Field(default_factory=list)
+    next_steps: List[Any] = Field(default_factory=list)

From 6533e612ef29819ca48aaa000ca2c1aee0820897 Mon Sep 17 00:00:00 2001
From: Francia Riesco
Date: Mon, 20 Oct 2025 14:18:04 -0400
Subject: [PATCH 4/9] Remove check_deployments utility script

Deleted src/backend/common/utils/check_deployments.py as it is no longer needed. This script was used to check Azure AI Foundry model deployments and their statuses.

---
 src/backend/common/utils/check_deployments.py | 50 -------------------
 1 file changed, 50 deletions(-)
 delete mode 100644 src/backend/common/utils/check_deployments.py

diff --git a/src/backend/common/utils/check_deployments.py b/src/backend/common/utils/check_deployments.py
deleted file mode 100644
index 614c65ea4..000000000
--- a/src/backend/common/utils/check_deployments.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import asyncio
-import os
-import sys
-import traceback
-
-# Add the backend directory to the Python path
-backend_path = os.path.join(os.path.dirname(__file__), "..", "..")
-sys.path.insert(0, backend_path)
-
-try:
-    from v3.common.services.foundry_service import FoundryService
-except ImportError as e:
-    print(f"❌ Import error: {e}")
-    sys.exit(1)
-
-
-async def check_deployments():
-    try:
-        print("🔍 Checking Azure AI Foundry model deployments...")
-        foundry_service = FoundryService()
-        deployments = await foundry_service.list_model_deployments()
-
-        # Filter successful deployments
-        successful_deployments = [
-            d for d in deployments if d.get("status") == "Succeeded"
-        ]
-
-        print(
-            f"✅ Total deployments: {len(deployments)} (Successful: {len(successful_deployments)})"
-        )
-
-        available_models = [d.get("name", "").lower() for d in successful_deployments]
-
-        # Check what we're looking for
-        required_models = ["gpt-4o", "o3", "gpt-4", "gpt-35-turbo"]
-
-        print(f"\n🔍 Checking required models: {required_models}")
-        for model in required_models:
-            if model.lower() in available_models:
-                print(f"✅ {model} is available")
-            else:
-                print(f"❌ {model} is NOT available")
-
-    except Exception as e:
-        print(f"❌ Error: {e}")
-        traceback.print_exc()
-
-
-if __name__ == "__main__":
-    asyncio.run(check_deployments())

From bdb1bbf612c899c3e6f33a646f8726f708809dd2 Mon Sep 17 00:00:00 2001
From: Francia Riesco
Date: Tue, 21 Oct 2025 10:03:09 -0400
Subject: [PATCH 5/9] Add agent framework API router and service skeletons

Introduces the main FastAPI router for agent framework v3, including endpoints for team management, plan creation, approval, agent messaging, and configuration upload. Adds callback handlers for agent responses and streaming, global debug access, and service skeletons for agents, base API, foundry, MCP, and team management. These files establish the backend structure for multi-agent orchestration and extensible service integration.
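A rough sketch of the client flow these endpoints are designed to support
(the host, port, header value, and the httpx/websockets client libraries are
illustrative assumptions, not part of this patch):

    import asyncio
    import json

    import httpx
    import websockets  # assumed WebSocket client library


    async def run_task(description: str) -> None:
        # Header name taken from the endpoint docstrings; the value is a placeholder.
        headers = {"user_principal_id": "local-test-user"}
        async with httpx.AsyncClient(base_url="http://localhost:8000/api/v3") as client:
            # Load (or default) the user's current team before submitting work.
            await client.get("/init_team", headers=headers)
            # Create the plan; orchestration then continues as a background task.
            resp = await client.post(
                "/process_request",
                headers=headers,
                json={"session_id": "", "description": description},
            )
            plan = resp.json()

        # Subscribe to real-time status updates, assuming the plan id doubles
        # as the process id for the socket endpoint.
        url = f"ws://localhost:8000/api/v3/socket/{plan['plan_id']}?user_id=local-test-user"
        async with websockets.connect(url) as ws:
            while True:
                print(json.loads(await ws.recv()))


    asyncio.run(run_task("Draft an onboarding plan for a new hire"))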
--- src/backend/af/__init__.py | 0 src/backend/af/api/router.py | 1338 +++++++++++++++++ src/backend/af/callbacks/__init__.py | 1 + src/backend/af/callbacks/global_debug.py | 14 + src/backend/af/callbacks/response_handlers.py | 206 +++ src/backend/af/common/services/__init__.py | 19 + .../af/common/services/agents_service.py | 121 ++ .../af/common/services/base_api_service.py | 114 ++ .../af/common/services/foundry_service.py | 116 ++ src/backend/af/common/services/mcp_service.py | 37 + .../af/common/services/plan_service.py | 254 ++++ .../af/common/services/team_service.py | 581 +++++++ src/backend/af/config/__init__.py | 1 + src/backend/af/config/agent_registry.py | 140 ++ src/backend/af/config/settings.py | 418 +++++ .../af/magentic_agents/foundry_agent.py | 294 ++++ .../magentic_agents/magentic_agent_factory.py | 196 +++ src/backend/af/magentic_agents/proxy_agent.py | 373 +++++ .../af/magentic_agents/reasoning_agent.py | 111 ++ .../af/magentic_agents/reasoning_search.py | 93 ++ src/backend/af/models/messages.py | 206 +++ src/backend/af/models/models.py | 35 + src/backend/af/models/orchestration_models.py | 56 + src/backend/af/orchestration/__init__.py | 1 + .../helper/plan_to_mplan_converter.py | 194 +++ .../orchestration/human_approval_manager.py | 370 +++++ .../af/orchestration/orchestration_manager.py | 180 +++ 27 files changed, 5469 insertions(+) create mode 100644 src/backend/af/__init__.py create mode 100644 src/backend/af/api/router.py create mode 100644 src/backend/af/callbacks/__init__.py create mode 100644 src/backend/af/callbacks/global_debug.py create mode 100644 src/backend/af/callbacks/response_handlers.py create mode 100644 src/backend/af/common/services/__init__.py create mode 100644 src/backend/af/common/services/agents_service.py create mode 100644 src/backend/af/common/services/base_api_service.py create mode 100644 src/backend/af/common/services/foundry_service.py create mode 100644 src/backend/af/common/services/mcp_service.py create mode 100644 src/backend/af/common/services/plan_service.py create mode 100644 src/backend/af/common/services/team_service.py create mode 100644 src/backend/af/config/__init__.py create mode 100644 src/backend/af/config/agent_registry.py create mode 100644 src/backend/af/config/settings.py create mode 100644 src/backend/af/magentic_agents/foundry_agent.py create mode 100644 src/backend/af/magentic_agents/magentic_agent_factory.py create mode 100644 src/backend/af/magentic_agents/proxy_agent.py create mode 100644 src/backend/af/magentic_agents/reasoning_agent.py create mode 100644 src/backend/af/magentic_agents/reasoning_search.py create mode 100644 src/backend/af/models/messages.py create mode 100644 src/backend/af/models/models.py create mode 100644 src/backend/af/models/orchestration_models.py create mode 100644 src/backend/af/orchestration/__init__.py create mode 100644 src/backend/af/orchestration/helper/plan_to_mplan_converter.py create mode 100644 src/backend/af/orchestration/human_approval_manager.py create mode 100644 src/backend/af/orchestration/orchestration_manager.py diff --git a/src/backend/af/__init__.py b/src/backend/af/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/af/api/router.py b/src/backend/af/api/router.py new file mode 100644 index 000000000..bf654444d --- /dev/null +++ b/src/backend/af/api/router.py @@ -0,0 +1,1338 @@ +import asyncio +import json +import logging +import uuid +from typing import Optional + +import af.models.messages as messages +from auth.auth_utils import 
get_authenticated_user_details
+from common.database.database_factory import DatabaseFactory
+from common.models.messages_kernel import (
+    InputTask,
+    Plan,
+    PlanStatus,
+    TeamSelectionRequest,
+)
+from common.utils.event_utils import track_event_if_configured
+from common.utils.utils_kernel import rai_success, rai_validate_team_config
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    File,
+    HTTPException,
+    Query,
+    Request,
+    UploadFile,
+    WebSocket,
+    WebSocketDisconnect,
+)
+from af.common.services.plan_service import PlanService
+from af.common.services.team_service import TeamService
+from af.config.settings import (
+    connection_config,
+    orchestration_config,
+    team_config,
+)
+from af.orchestration.orchestration_manager import OrchestrationManager
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+app_v3 = APIRouter(
+    prefix="/api/v3",
+    responses={404: {"description": "Not found"}},
+)
+
+
+@app_v3.websocket("/socket/{process_id}")
+async def start_comms(
+    websocket: WebSocket, process_id: str, user_id: str = Query(None)
+):
+    """WebSocket endpoint for real-time process status updates."""
+
+    # Always accept the WebSocket connection first
+    await websocket.accept()
+
+    user_id = user_id or "00000000-0000-0000-0000-000000000000"
+
+    # Add to the connection manager for backend updates
+    connection_config.add_connection(
+        process_id=process_id, connection=websocket, user_id=user_id
+    )
+    track_event_if_configured(
+        "WebSocketConnectionAccepted", {"process_id": process_id, "user_id": user_id}
+    )
+
+    # Keep the connection open - FastAPI will close the connection if this returns
+    try:
+        while True:
+            # No message is expected from the client, but reading here keeps
+            # the connection open without consuming CPU cycles.
+            try:
+                message = await websocket.receive_text()
+                logging.debug(f"Received WebSocket message from {user_id}: {message}")
+            except asyncio.TimeoutError:
+                pass
+            except WebSocketDisconnect:
+                track_event_if_configured(
+                    "WebSocketDisconnect",
+                    {"process_id": process_id, "user_id": user_id},
+                )
+                logging.info(f"Client disconnected from batch {process_id}")
+                break
+    except Exception as e:
+        # Fixed logging syntax - removed the error= parameter
+        logging.error(f"Error in WebSocket connection: {str(e)}")
+    finally:
+        # Always clean up the connection
+        await connection_config.close_connection(process_id=process_id)
+
+
+@app_v3.get("/init_team")
+async def init_team(
+    request: Request,
+    team_switched: bool = Query(False),
+):  # add team_switched: bool parameter
+    """Initialize the user's current team of agents."""
+
+    # This user state needs to be stored in Cosmos DB, retrieved here, and used to
+    # initialize the team; the current in-memory store is team_config in settings.py.
+    # For now the initial install team ids are 00000000-0000-0000-0000-000000000001 (HR),
+    # 00000000-0000-0000-0000-000000000002 (Marketing), and 00000000-0000-0000-0000-000000000003 (Retail),
+    # and the HR id is used to initialize the team each time.
+ init_team_id = "00000000-0000-0000-0000-000000000001" + print(f"Init team called, team_switched={team_switched}") + try: + authenticated_user = get_authenticated_user_details( + request_headers=request.headers + ) + user_id = authenticated_user["user_principal_id"] + if not user_id: + track_event_if_configured( + "UserIdNotFound", {"status_code": 400, "detail": "no user"} + ) + raise HTTPException(status_code=400, detail="no user") + + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + user_current_team = await memory_store.get_current_team(user_id=user_id) + if not user_current_team: + print("User has no current team, setting to default:", init_team_id) + user_current_team = await team_service.handle_team_selection( + user_id=user_id, team_id=init_team_id + ) + if user_current_team: + init_team_id = user_current_team.team_id + else: + init_team_id = user_current_team.team_id + # Verify the team exists and user has access to it + team_configuration = await team_service.get_team_configuration( + init_team_id, user_id + ) + if team_configuration is None: + raise HTTPException( + status_code=404, + detail=f"Team configuration '{init_team_id}' not found or access denied", + ) + + # Set as current team in memory + team_config.set_current_team( + user_id=user_id, team_configuration=team_configuration + ) + + # Initialize agent team for this user session + await OrchestrationManager.get_current_or_new_orchestration( + user_id=user_id, team_config=team_configuration, team_switched=team_switched + ) + + return { + "status": "Request started successfully", + "team_id": init_team_id, + "team": team_configuration, + } + + except Exception as e: + track_event_if_configured( + "InitTeamFailed", + { + "error": str(e), + }, + ) + raise HTTPException( + status_code=400, detail=f"Error starting request: {e}" + ) from e + + +@app_v3.post("/process_request") +async def process_request( + background_tasks: BackgroundTasks, input_task: InputTask, request: Request +): + """ + Create a new plan without full processing. 
+ + --- + tags: + - Plans + parameters: + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + - name: body + in: body + required: true + schema: + type: object + properties: + session_id: + type: string + description: Session ID for the plan + description: + type: string + description: The task description to validate and create plan for + responses: + 200: + description: Plan created successfully + schema: + type: object + properties: + plan_id: + type: string + description: The ID of the newly created plan + status: + type: string + description: Success message + session_id: + type: string + description: Session ID associated with the plan + 400: + description: RAI check failed or invalid input + schema: + type: object + properties: + detail: + type: string + description: Error message + """ + + if not await rai_success(input_task.description): + track_event_if_configured( + "RAI failed", + { + "status": "Plan not created - RAI check failed", + "description": input_task.description, + "session_id": input_task.session_id, + }, + ) + raise HTTPException( + status_code=400, + detail="Request contains content that doesn't meet our safety guidelines, try again.", + ) + + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + + if not user_id: + track_event_if_configured( + "UserIdNotFound", {"status_code": 400, "detail": "no user"} + ) + raise HTTPException(status_code=400, detail="no user found") + + # if not input_task.team_id: + # track_event_if_configured( + # "TeamIDNofound", {"status_code": 400, "detail": "no team id"} + # ) + # raise HTTPException(status_code=400, detail="no team id") + + if not input_task.session_id: + input_task.session_id = str(uuid.uuid4()) + try: + plan_id = str(uuid.uuid4()) + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + user_current_team = await memory_store.get_current_team(user_id=user_id) + team_id = None + if user_current_team: + team_id = user_current_team.team_id + team = await memory_store.get_team_by_id(team_id=team_id) + if not team: + raise HTTPException( + status_code=404, + detail=f"Team configuration '{team_id}' not found or access denied", + ) + plan = Plan( + id=plan_id, + plan_id=plan_id, + user_id=user_id, + session_id=input_task.session_id, + team_id=team_id, + initial_goal=input_task.description, + overall_status=PlanStatus.in_progress, + ) + await memory_store.add_plan(plan) + + track_event_if_configured( + "PlanCreated", + { + "status": "success", + "plan_id": plan.plan_id, + "session_id": input_task.session_id, + "user_id": user_id, + "team_id": team_id, + "description": input_task.description, + }, + ) + except Exception as e: + print(f"Error creating plan: {e}") + track_event_if_configured( + "PlanCreationFailed", + { + "status": "error", + "description": input_task.description, + "session_id": input_task.session_id, + "user_id": user_id, + "error": str(e), + }, + ) + raise HTTPException(status_code=500, detail="Failed to create plan") + + try: + # background_tasks.add_task( + # lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task) + # ) + + async def run_orchestration_task(): + await OrchestrationManager().run_orchestration(user_id, input_task) + + background_tasks.add_task(run_orchestration_task) + + return { + "status": "Request started successfully", + 
"session_id": input_task.session_id, + "plan_id": plan_id, + } + + except Exception as e: + track_event_if_configured( + "RequestStartFailed", + { + "session_id": input_task.session_id, + "description": input_task.description, + "error": str(e), + }, + ) + raise HTTPException( + status_code=400, detail=f"Error starting request: {e}" + ) from e + + +@app_v3.post("/plan_approval") +async def plan_approval( + human_feedback: messages.PlanApprovalResponse, request: Request +): + """ + Endpoint to receive plan approval or rejection from the user. + + --- + tags: + - Plans + parameters: + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + requestBody: + description: Plan approval payload + required: true + content: + application/json: + schema: + type: object + properties: + m_plan_id: + type: string + description: The internal m_plan id for the plan (required) + approved: + type: boolean + description: Whether the plan is approved (true) or rejected (false) + feedback: + type: string + description: Optional feedback or comment from the user + plan_id: + type: string + description: Optional user-facing plan_id + responses: + 200: + description: Approval recorded successfully + content: + application/json: + schema: + type: object + properties: + status: + type: string + 401: + description: Missing or invalid user information + 404: + description: No active plan found for approval + 500: + description: Internal server error + """ + + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + # Set the approval in the orchestration config + try: + if user_id and human_feedback.m_plan_id: + if ( + orchestration_config + and human_feedback.m_plan_id in orchestration_config.approvals + ): + orchestration_config.set_approval_result( + human_feedback.m_plan_id, human_feedback.approved + ) + # orchestration_config.plans[human_feedback.m_plan_id][ + # "plan_id" + # ] = human_feedback.plan_id + print("Plan approval received:", human_feedback) + # print( + # "Updated orchestration config:", + # orchestration_config.plans[human_feedback.m_plan_id], + # ) + try: + result = await PlanService.handle_plan_approval( + human_feedback, user_id + ) + print("Plan approval processed:", result) + except ValueError as ve: + print(f"ValueError processing plan approval: {ve}") + except Exception as e: + print(f"Error processing plan approval: {e}") + track_event_if_configured( + "PlanApprovalReceived", + { + "plan_id": human_feedback.plan_id, + "m_plan_id": human_feedback.m_plan_id, + "approved": human_feedback.approved, + "user_id": user_id, + "feedback": human_feedback.feedback, + }, + ) + + return {"status": "approval recorded"} + else: + logging.warning( + f"No orchestration or plan found for plan_id: {human_feedback.m_plan_id}" + ) + raise HTTPException( + status_code=404, detail="No active plan found for approval" + ) + except Exception as e: + logging.error(f"Error processing plan approval: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app_v3.post("/user_clarification") +async def user_clarification( + human_feedback: messages.UserClarificationResponse, request: Request +): + """ + Endpoint to receive user clarification responses for clarification requests sent by the system. 
+
+    ---
+    tags:
+      - Plans
+    parameters:
+      - name: user_principal_id
+        in: header
+        type: string
+        required: true
+        description: User ID extracted from the authentication header
+    requestBody:
+      description: User clarification payload
+      required: true
+      content:
+        application/json:
+          schema:
+            type: object
+            properties:
+              request_id:
+                type: string
+                description: The clarification request id sent by the system (required)
+              answer:
+                type: string
+                description: The user's answer or clarification text
+              plan_id:
+                type: string
+                description: (Optional) Associated plan_id
+              m_plan_id:
+                type: string
+                description: (Optional) Internal m_plan id
+    responses:
+      200:
+        description: Clarification recorded successfully
+      400:
+        description: RAI check failed or invalid input
+      401:
+        description: Missing or invalid user information
+      404:
+        description: No active plan found for clarification
+      500:
+        description: Internal server error
+    """
+
+    authenticated_user = get_authenticated_user_details(request_headers=request.headers)
+    user_id = authenticated_user["user_principal_id"]
+    if not user_id:
+        raise HTTPException(
+            status_code=401, detail="Missing or invalid user information"
+        )
+    # Record the clarification in the orchestration config
+    if user_id and human_feedback.request_id:
+        # Run the RAI content check only when a non-empty answer was provided
+        if human_feedback.answer:
+            if not await rai_success(human_feedback.answer):
+                track_event_if_configured(
+                    "RAI failed",
+                    {
+                        "status": "Plan Clarification",
+                        "description": human_feedback.answer,
+                        "request_id": human_feedback.request_id,
+                    },
+                )
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error_type": "RAI_VALIDATION_FAILED",
+                        "message": "Content Safety Check Failed",
+                        "description": "Your request contains content that doesn't meet our safety guidelines. Please modify your request to ensure it's appropriate and try again.",
+                        "suggestions": [
+                            "Remove any potentially harmful, inappropriate, or unsafe content",
+                            "Use more professional and constructive language",
+                            "Focus on legitimate business or educational objectives",
+                            "Ensure your request complies with content policies",
+                        ],
+                        "user_action": "Please revise your request and try again",
+                    },
+                )
+
+        if (
+            orchestration_config
+            and human_feedback.request_id in orchestration_config.clarifications
+        ):
+            # Use the event-driven method to set the clarification result
+            orchestration_config.set_clarification_result(
+                human_feedback.request_id, human_feedback.answer
+            )
+            try:
+                result = await PlanService.handle_human_clarification(
+                    human_feedback, user_id
+                )
+                print("Human clarification processed:", result)
+            except ValueError as ve:
+                print(f"ValueError processing human clarification: {ve}")
+            except Exception as e:
+                print(f"Error processing human clarification: {e}")
+            track_event_if_configured(
+                "HumanClarificationReceived",
+                {
+                    "request_id": human_feedback.request_id,
+                    "answer": human_feedback.answer,
+                    "user_id": user_id,
+                },
+            )
+            return {
+                "status": "clarification recorded",
+            }
+        else:
+            logging.warning(
+                f"No orchestration or plan found for request_id: {human_feedback.request_id}"
+            )
+            raise HTTPException(
+                status_code=404, detail="No active plan found for clarification"
+            )
+
+
+@app_v3.post("/agent_message")
+async def agent_message_user(
+    agent_message: messages.AgentMessageResponse, request: Request
+):
+    """
+    Endpoint to receive messages from agents (agent -> user communication).
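+    An illustrative request body (values are examples only):
+
+        {"plan_id": "plan-456", "agent": "HR_Agent", "content": "Offer letter drafted", "agent_type": "AI"}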
+ + --- + tags: + - Agents + parameters: + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + requestBody: + description: Agent message payload + required: true + content: + application/json: + schema: + type: object + properties: + plan_id: + type: string + description: ID of the plan this message relates to + agent: + type: string + description: Name or identifier of the agent sending the message + content: + type: string + description: The message content + agent_type: + type: string + description: Type of agent (AI/Human) + m_plan_id: + type: string + description: Optional internal m_plan id + responses: + 200: + description: Message recorded successfully + schema: + type: object + properties: + status: + type: string + 401: + description: Missing or invalid user information + """ + + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + # Set the approval in the orchestration config + + try: + + result = await PlanService.handle_agent_messages(agent_message, user_id) + print("Agent message processed:", result) + except ValueError as ve: + print(f"ValueError processing agent message: {ve}") + except Exception as e: + print(f"Error processing agent message: {e}") + + track_event_if_configured( + "AgentMessageReceived", + { + "agent": agent_message.agent, + "content": agent_message.content, + "user_id": user_id, + }, + ) + return { + "status": "message recorded", + } + + +@app_v3.post("/upload_team_config") +async def upload_team_config( + request: Request, + file: UploadFile = File(...), + team_id: Optional[str] = Query(None), +): + """ + Upload and save a team configuration JSON file. 
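+    A minimal accepted file shape (illustrative; see TeamService.validate_and_parse_team_config
+    for the full validation rules):
+
+        {
+            "name": "HR Team",
+            "status": "active",
+            "agents": [{"input_key": "hr", "type": "rag", "name": "HR_Agent", "icon": "hr.png"}],
+            "starting_tasks": [{"id": "t1", "name": "Onboard", "prompt": "Onboard a new hire",
+                                "created": "2025-01-01", "creator": "admin", "logo": "task.png"}]
+        }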
+ + --- + tags: + - Team Configuration + parameters: + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + - name: file + in: formData + type: file + required: true + description: JSON file containing team configuration + responses: + 200: + description: Team configuration uploaded successfully + 400: + description: Invalid request or file format + 401: + description: Missing or invalid user information + 500: + description: Internal server error + """ + # Validate user authentication + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + + # Validate file is provided and is JSON + if not file: + raise HTTPException(status_code=400, detail="No file provided") + + if not file.filename.endswith(".json"): + raise HTTPException(status_code=400, detail="File must be a JSON file") + + try: + # Read and parse JSON content + content = await file.read() + try: + json_data = json.loads(content.decode("utf-8")) + except json.JSONDecodeError as e: + raise HTTPException( + status_code=400, detail=f"Invalid JSON format: {str(e)}" + ) + + # Validate content with RAI before processing + if not team_id: + rai_valid, rai_error = await rai_validate_team_config(json_data) + if not rai_valid: + track_event_if_configured( + "Team configuration RAI validation failed", + { + "status": "failed", + "user_id": user_id, + "filename": file.filename, + "reason": rai_error, + }, + ) + raise HTTPException(status_code=400, detail=rai_error) + + track_event_if_configured( + "Team configuration RAI validation passed", + {"status": "passed", "user_id": user_id, "filename": file.filename}, + ) + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + + # Validate model deployments + models_valid, missing_models = await team_service.validate_team_models( + json_data + ) + if not models_valid: + error_message = ( + f"The following required models are not deployed in your Azure AI project: {', '.join(missing_models)}. " + f"Please deploy these models in Azure AI Foundry before uploading this team configuration." + ) + track_event_if_configured( + "Team configuration model validation failed", + { + "status": "failed", + "user_id": user_id, + "filename": file.filename, + "missing_models": missing_models, + }, + ) + raise HTTPException(status_code=400, detail=error_message) + + track_event_if_configured( + "Team configuration model validation passed", + {"status": "passed", "user_id": user_id, "filename": file.filename}, + ) + + # Validate search indexes + search_valid, search_errors = await team_service.validate_team_search_indexes( + json_data + ) + if not search_valid: + error_message = ( + f"Search index validation failed:\n\n{chr(10).join([f'β€’ {error}' for error in search_errors])}\n\n" + f"Please ensure all referenced search indexes exist in your Azure AI Search service." 
+ ) + track_event_if_configured( + "Team configuration search validation failed", + { + "status": "failed", + "user_id": user_id, + "filename": file.filename, + "search_errors": search_errors, + }, + ) + raise HTTPException(status_code=400, detail=error_message) + + track_event_if_configured( + "Team configuration search validation passed", + {"status": "passed", "user_id": user_id, "filename": file.filename}, + ) + + # Validate and parse the team configuration + try: + team_config = await team_service.validate_and_parse_team_config( + json_data, user_id + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Save the configuration + try: + print("Saving team configuration...", team_id) + if team_id: + team_config.team_id = team_id + team_config.id = team_id # Ensure id is also set for updates + team_id = await team_service.save_team_configuration(team_config) + except ValueError as e: + raise HTTPException( + status_code=500, detail=f"Failed to save configuration: {str(e)}" + ) + + track_event_if_configured( + "Team configuration uploaded", + { + "status": "success", + "team_id": team_id, + "user_id": user_id, + "agents_count": len(team_config.agents), + "tasks_count": len(team_config.starting_tasks), + }, + ) + + return { + "status": "success", + "team_id": team_id, + "name": team_config.name, + "message": "Team configuration uploaded and saved successfully", + "team": team_config.model_dump(), # Return the full team configuration + } + + except HTTPException: + raise + except Exception as e: + logging.error(f"Unexpected error uploading team configuration: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error occurred") + + +@app_v3.get("/team_configs") +async def get_team_configs(request: Request): + """ + Retrieve all team configurations for the current user. + + --- + tags: + - Team Configuration + parameters: + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + responses: + 200: + description: List of team configurations for the user + schema: + type: array + items: + type: object + properties: + id: + type: string + team_id: + type: string + name: + type: string + status: + type: string + created: + type: string + created_by: + type: string + description: + type: string + logo: + type: string + plan: + type: string + agents: + type: array + starting_tasks: + type: array + 401: + description: Missing or invalid user information + """ + # Validate user authentication + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + + try: + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + + # Retrieve all team configurations + team_configs = await team_service.get_all_team_configurations() + + # Convert to dictionaries for response + configs_dict = [config.model_dump() for config in team_configs] + + return configs_dict + + except Exception as e: + logging.error(f"Error retrieving team configurations: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error occurred") + + +@app_v3.get("/team_configs/{team_id}") +async def get_team_config_by_id(team_id: str, request: Request): + """ + Retrieve a specific team configuration by ID. 
+ + --- + tags: + - Team Configuration + parameters: + - name: team_id + in: path + type: string + required: true + description: The ID of the team configuration to retrieve + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + responses: + 200: + description: Team configuration details + schema: + type: object + properties: + id: + type: string + team_id: + type: string + name: + type: string + status: + type: string + created: + type: string + created_by: + type: string + description: + type: string + logo: + type: string + plan: + type: string + agents: + type: array + starting_tasks: + type: array + 401: + description: Missing or invalid user information + 404: + description: Team configuration not found + """ + # Validate user authentication + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + + try: + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + + # Retrieve the specific team configuration + team_config = await team_service.get_team_configuration(team_id, user_id) + + if team_config is None: + raise HTTPException(status_code=404, detail="Team configuration not found") + + # Convert to dictionary for response + return team_config.model_dump() + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logging.error(f"Error retrieving team configuration: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error occurred") + + +@app_v3.delete("/team_configs/{team_id}") +async def delete_team_config(team_id: str, request: Request): + """ + Delete a team configuration by ID. + + --- + tags: + - Team Configuration + parameters: + - name: team_id + in: path + type: string + required: true + description: The ID of the team configuration to delete + - name: user_principal_id + in: header + type: string + required: true + description: User ID extracted from the authentication header + responses: + 200: + description: Team configuration deleted successfully + schema: + type: object + properties: + status: + type: string + message: + type: string + team_id: + type: string + 401: + description: Missing or invalid user information + 404: + description: Team configuration not found + """ + # Validate user authentication + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + + try: + # To do: Check if the team is the users current team, or if it is + # used in any active sessions/plans. Refuse request if so. 
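+        # A possible sketch of that guard (kept commented out until the
+        # active-plan check is designed; it would run after memory_store is
+        # initialized below):
+        #     current_team = await memory_store.get_current_team(user_id=user_id)
+        #     if current_team and current_team.team_id == team_id:
+        #         raise HTTPException(
+        #             status_code=409,
+        #             detail="Team is currently selected and cannot be deleted",
+        #         )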
+ + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + + # Delete the team configuration + deleted = await team_service.delete_team_configuration(team_id, user_id) + + if not deleted: + raise HTTPException(status_code=404, detail="Team configuration not found") + + # Track the event + track_event_if_configured( + "Team configuration deleted", + {"status": "success", "team_id": team_id, "user_id": user_id}, + ) + + return { + "status": "success", + "message": "Team configuration deleted successfully", + "team_id": team_id, + } + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logging.error(f"Error deleting team configuration: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error occurred") + + +@app_v3.post("/select_team") +async def select_team(selection: TeamSelectionRequest, request: Request): + """ + Select the current team for the user session. + """ + # Validate user authentication + authenticated_user = get_authenticated_user_details(request_headers=request.headers) + user_id = authenticated_user["user_principal_id"] + if not user_id: + raise HTTPException( + status_code=401, detail="Missing or invalid user information" + ) + + if not selection.team_id: + raise HTTPException(status_code=400, detail="Team ID is required") + + try: + # Initialize memory store and service + memory_store = await DatabaseFactory.get_database(user_id=user_id) + team_service = TeamService(memory_store) + + # Verify the team exists and user has access to it + team_configuration = await team_service.get_team_configuration( + selection.team_id, user_id + ) + if team_configuration is None: # ensure that id is valid + raise HTTPException( + status_code=404, + detail=f"Team configuration '{selection.team_id}' not found or access denied", + ) + set_team = await team_service.handle_team_selection( + user_id=user_id, team_id=selection.team_id + ) + if not set_team: + track_event_if_configured( + "Team selected", + { + "status": "failed", + "team_id": selection.team_id, + "team_name": team_configuration.name, + "user_id": user_id, + }, + ) + raise HTTPException( + status_code=404, + detail=f"Team configuration '{selection.team_id}' failed to set", + ) + + # save to in-memory config for current user + team_config.set_current_team( + user_id=user_id, team_configuration=team_configuration + ) + + # Track the team selection event + track_event_if_configured( + "Team selected", + { + "status": "success", + "team_id": selection.team_id, + "team_name": team_configuration.name, + "user_id": user_id, + }, + ) + + return { + "status": "success", + "message": f"Team '{team_configuration.name}' selected successfully", + "team_id": selection.team_id, + "team_name": team_configuration.name, + "agents_count": len(team_configuration.agents), + "team_description": team_configuration.description, + } + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logging.error(f"Error selecting team: {str(e)}") + track_event_if_configured( + "Team selection error", + { + "status": "error", + "team_id": selection.team_id, + "user_id": user_id, + "error": str(e), + }, + ) + raise HTTPException(status_code=500, detail="Internal server error occurred") + + +# Get plans is called in the initial side rendering of the frontend +@app_v3.get("/plans") +async def get_plans(request: Request): + """ + Retrieve plans for the current user. 
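+    Only completed plans for the user's currently selected team are returned;
+    an empty list is returned when no team is currently selected.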
+
+    ---
+    tags:
+      - Plans
+    parameters:
+      - name: user_principal_id
+        in: header
+        type: string
+        required: true
+        description: User ID extracted from the authentication header
+    responses:
+      200:
+        description: List of completed plans for the user's currently selected team
+        schema:
+          type: array
+          items:
+            type: object
+            properties:
+              id:
+                type: string
+                description: Unique ID of the plan
+              session_id:
+                type: string
+                description: Session ID associated with the plan
+              initial_goal:
+                type: string
+                description: The initial goal derived from the user's input
+              overall_status:
+                type: string
+                description: Status of the plan (e.g., in_progress, completed)
+              steps:
+                type: array
+                items:
+                  type: object
+                  properties:
+                    id:
+                      type: string
+                      description: Unique ID of the step
+                    plan_id:
+                      type: string
+                      description: ID of the plan the step belongs to
+                    action:
+                      type: string
+                      description: The action to be performed
+                    agent:
+                      type: string
+                      description: The agent responsible for the step
+                    status:
+                      type: string
+                      description: Status of the step (e.g., planned, approved, completed)
+      400:
+        description: Missing or invalid user information
+    """
+
+    authenticated_user = get_authenticated_user_details(request_headers=request.headers)
+    user_id = authenticated_user["user_principal_id"]
+    if not user_id:
+        track_event_if_configured(
+            "UserIdNotFound", {"status_code": 400, "detail": "no user"}
+        )
+        raise HTTPException(status_code=400, detail="no user")
+
+    # Initialize memory context
+    memory_store = await DatabaseFactory.get_database(user_id=user_id)
+
+    current_team = await memory_store.get_current_team(user_id=user_id)
+    if not current_team:
+        return []
+
+    all_plans = await memory_store.get_all_plans_by_team_id_status(
+        user_id=user_id, team_id=current_team.team_id, status=PlanStatus.completed
+    )
+
+    return all_plans
+
+
+# Get plan is called when the frontend opens a specific plan
+@app_v3.get("/plan")
+async def get_plan_by_id(
+    request: Request,
+    plan_id: Optional[str] = Query(None),
+):
+    """
+    Retrieve a single plan, with its team, agent messages, and streaming state, by plan_id.
+
+    ---
+    tags:
+      - Plans
+    parameters:
+      - name: plan_id
+        in: query
+        type: string
+        required: true
+        description: The plan_id of the plan to retrieve
+      - name: user_principal_id
+        in: header
+        type: string
+        required: true
+        description: User ID extracted from the authentication header
+    responses:
+      200:
+        description: The plan together with its team, agent messages, m_plan, and buffered streaming message
+        schema:
+          type: object
+          properties:
+            plan:
+              type: object
+              description: The plan record
+            team:
+              type: object
+              description: The team configuration the plan belongs to
+            messages:
+              type: array
+              description: Agent messages recorded for the plan
+            m_plan:
+              type: object
+              description: The internal m_plan, if one was stored
+            streaming_message:
+              type: string
+              description: Buffered streaming output, cleared after retrieval
+      400:
+        description: Missing or invalid user information, or missing plan_id
+      404:
+        description: Plan not found
+      500:
+        description: Internal server error
+    """
+
+    authenticated_user = get_authenticated_user_details(request_headers=request.headers)
+    user_id = authenticated_user["user_principal_id"]
+    if not user_id:
+        track_event_if_configured(
+            "UserIdNotFound", {"status_code": 400, "detail": "no user"}
+        )
+        raise HTTPException(status_code=400, detail="no user")
+
+    # Initialize memory context
+    memory_store = await DatabaseFactory.get_database(user_id=user_id)
+    try:
+        if plan_id:
+            plan = await memory_store.get_plan_by_plan_id(plan_id=plan_id)
+            if not plan:
+                track_event_if_configured(
+                    "GetPlanByIdNotFound",
+                    {"status_code": 404, "detail": "Plan not found"},
+                )
+                raise HTTPException(status_code=404, detail="Plan not found")
+
+            team = await memory_store.get_team_by_id(team_id=plan.team_id)
+            agent_messages = await memory_store.get_agent_messages(plan_id=plan.plan_id)
+            mplan = plan.m_plan if plan.m_plan else None
+            streaming_message = plan.streaming_message if plan.streaming_message else ""
+            plan.streaming_message = ""  # clear streaming message after retrieval
+            plan.m_plan = None  # remove m_plan from plan object for response
+            return {
+                "plan": plan,
+                "team": team if team else None,
+                "messages": agent_messages,
+                "m_plan": mplan,
+                "streaming_message": streaming_message,
+            }
+        else:
+            track_event_if_configured(
+                "GetPlanId", {"status_code": 400, "detail": "no plan id"}
+            )
+            raise HTTPException(status_code=400, detail="no plan id")
+    except HTTPException:
+        # Re-raise HTTP exceptions so the 400/404 above are not converted to 500s
+        raise
+    except Exception as e:
+        logging.error(f"Error retrieving plan: {str(e)}")
+        raise HTTPException(status_code=500, detail="Internal server error occurred")
diff --git a/src/backend/af/callbacks/__init__.py b/src/backend/af/callbacks/__init__.py
new file mode 100644
index 000000000..35deaed79
--- /dev/null
+++ b/src/backend/af/callbacks/__init__.py
@@ -0,0 +1 @@
+# Callbacks package for handling agent responses and streaming
diff --git a/src/backend/af/callbacks/global_debug.py b/src/backend/af/callbacks/global_debug.py
new file mode 100644
index 000000000..3da87681f
--- /dev/null
+++ b/src/backend/af/callbacks/global_debug.py
@@ -0,0 +1,14 @@
+class DebugGlobalAccess:
+    """Class to manage global access to the Magentic orchestration manager."""
+
+    _managers = []
+
+    @classmethod
+    def 
add_manager(cls, manager): + """Add a new manager to the global list.""" + cls._managers.append(manager) + + @classmethod + def get_managers(cls): + """Get the list of all managers.""" + return cls._managers diff --git a/src/backend/af/callbacks/response_handlers.py b/src/backend/af/callbacks/response_handlers.py new file mode 100644 index 000000000..699f06a0f --- /dev/null +++ b/src/backend/af/callbacks/response_handlers.py @@ -0,0 +1,206 @@ +""" +Agent Framework response callbacks for employee onboarding / multi-agent system. +Replaces Semantic Kernel message types with agent_framework ChatResponseUpdate handling. +""" + +import asyncio +import json +import logging +import re +import time +from typing import Optional + +from agent_framework import ( + ChatResponseUpdate, + FunctionCallContent, + UsageContent, + Role, + TextContent, +) + +from af.config.settings import connection_config +from af.models.messages import ( + AgentMessage, + AgentMessageStreaming, + AgentToolCall, + AgentToolMessage, + WebsocketMessageType, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- + +_CITATION_PATTERNS = [ + (r"\[\d+:\d+\|source\]", ""), # [9:0|source] + (r"\[\s*source\s*\]", ""), # [source] + (r"\[\d+\]", ""), # [12] + (r"【[^】]*】", ""), # Unicode bracket citations + (r"\(source:[^)]*\)", ""), # (source: xyz) + (r"\[source:[^\]]*\]", ""), # [source: xyz] +] + + +def clean_citations(text: str) -> str: + """Remove citation markers from agent responses while preserving formatting.""" + if not text: + return text + for pattern, repl in _CITATION_PATTERNS: + text = re.sub(pattern, repl, text, flags=re.IGNORECASE) + return text + + +def _parse_function_arguments(arg_value: Optional[str | dict]) -> dict: + """Best-effort parse for function call arguments (stringified JSON or dict).""" + if arg_value is None: + return {} + if isinstance(arg_value, dict): + return arg_value + if isinstance(arg_value, str): + try: + return json.loads(arg_value) + except Exception: # noqa: BLE001 + return {"raw": arg_value} + return {"raw": str(arg_value)} + + +# --------------------------------------------------------------------------- +# Core handlers +# --------------------------------------------------------------------------- + +def agent_framework_update_callback( + update: ChatResponseUpdate, + user_id: Optional[str] = None, +) -> None: + """ + Handle a non-streaming perspective of updates (tool calls, intermediate steps, final usage). + This can be called for each ChatResponseUpdate; it will route tool calls and standard text + messages to WebSocket. 
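+
+    Example (direct invocation; the user id value is illustrative):
+
+        agent_framework_update_callback(update, user_id="user-123")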
+ """ + agent_name = getattr(update, "model_id", None) or "Agent" + # Use Role or fallback + role = getattr(update, "role", Role.ASSISTANT) + + # Detect tool/function calls + function_call_contents = [ + c for c in (update.contents or []) + if isinstance(c, FunctionCallContent) + ] + + if user_id is None: + return + + try: + if function_call_contents: + # Build tool message + tool_message = AgentToolMessage(agent_name=agent_name) + for fc in function_call_contents: + args = _parse_function_arguments(getattr(fc, "arguments", None)) + tool_message.tool_calls.append( + AgentToolCall( + tool_name=getattr(fc, "name", "unknown_tool"), + arguments=args, + ) + ) + asyncio.create_task( + connection_config.send_status_update_async( + tool_message, + user_id, + message_type=WebsocketMessageType.AGENT_TOOL_MESSAGE, + ) + ) + logger.info("Function call(s) dispatched: %s", tool_message) + return + + # Ignore pure usage or empty updates (handled as final in streaming handler) + if any(isinstance(c, UsageContent) for c in (update.contents or [])): + # We'll treat this as a final token accounting event; no standard message needed. + logger.debug("UsageContent received (final accounting); skipping text dispatch.") + return + + # Standard assistant/user message (non-stream delta) + if update.text: + final_message = AgentMessage( + agent_name=agent_name, + timestamp=str(time.time()), + content=clean_citations(update.text), + ) + asyncio.create_task( + connection_config.send_status_update_async( + final_message, + user_id, + message_type=WebsocketMessageType.AGENT_MESSAGE, + ) + ) + logger.info("%s message: %s", role.name.capitalize(), final_message) + + except Exception as e: # noqa: BLE001 + logger.error("agent_framework_update_callback: Error sending WebSocket message: %s", e) + + +async def streaming_agent_framework_callback( + update: ChatResponseUpdate, + user_id: Optional[str] = None, +) -> None: + """ + Handle streaming deltas. For each update with text, forward a streaming message. + Mark is_final=True when a UsageContent is observed (end of run). + """ + if user_id is None: + return + + try: + # Determine if this update marks the end + is_final = any(isinstance(c, UsageContent) for c in (update.contents or [])) + + # Streaming text can appear either in update.text or inside TextContent entries. 
+ pieces: list[str] = [] + if update.text: + pieces.append(update.text) + # Some events may provide TextContent objects without setting update.text + for c in (update.contents or []): + if isinstance(c, TextContent) and getattr(c, "text", None): + pieces.append(c.text) + + if not pieces: + return + + streaming_message = AgentMessageStreaming( + agent_name=getattr(update, "model_id", None) or "Agent", + content=clean_citations("".join(pieces)), + is_final=is_final, + ) + + await connection_config.send_status_update_async( + streaming_message, + user_id, + message_type=WebsocketMessageType.AGENT_MESSAGE_STREAMING, + ) + + if is_final: + logger.info("Final streaming chunk sent for agent '%s'", streaming_message.agent_name) + + except Exception as e: # noqa: BLE001 + logger.error("streaming_agent_framework_callback: Error sending streaming WebSocket message: %s", e) + + +# --------------------------------------------------------------------------- +# Convenience wrappers (optional) +# --------------------------------------------------------------------------- + +def handle_update(update: ChatResponseUpdate, user_id: Optional[str]) -> None: + """ + Unified entry point if caller doesn't distinguish streaming vs non-streaming. + You can call this once per update. It will: + - Forward streaming text increments + - Forward tool calls + - Skip purely usage-only events (except marking final in streaming) + """ + # Send streaming chunk first (async context) + asyncio.create_task(streaming_agent_framework_callback(update, user_id)) + # Then send non-stream items (tool calls or discrete messages) + agent_framework_update_callback(update, user_id) + diff --git a/src/backend/af/common/services/__init__.py b/src/backend/af/common/services/__init__.py new file mode 100644 index 000000000..4c07712c9 --- /dev/null +++ b/src/backend/af/common/services/__init__.py @@ -0,0 +1,19 @@ +"""Service abstractions for v3. + +Exports: +- BaseAPIService: minimal async HTTP wrapper using endpoints from AppConfig +- MCPService: service targeting a local/remote MCP server +- FoundryService: helper around Azure AI Foundry (AIProjectClient) +""" + +from .agents_service import AgentsService +from .base_api_service import BaseAPIService +from .foundry_service import FoundryService +from .mcp_service import MCPService + +__all__ = [ + "BaseAPIService", + "MCPService", + "FoundryService", + "AgentsService", +] diff --git a/src/backend/af/common/services/agents_service.py b/src/backend/af/common/services/agents_service.py new file mode 100644 index 000000000..fc4e7fa06 --- /dev/null +++ b/src/backend/af/common/services/agents_service.py @@ -0,0 +1,121 @@ +""" +AgentsService (skeleton) + +Lightweight service that receives a TeamService instance and exposes helper +methods to convert a TeamConfiguration into a list/array of agent descriptors. + +This is intentionally a simple skeleton β€” the user will later provide the +implementation that wires these descriptors into Semantic Kernel / Foundry +agent instances. +""" + +import logging +from typing import Any, Dict, List, Union + +from common.models.messages_kernel import TeamAgent, TeamConfiguration +from v3.common.services.team_service import TeamService + + +class AgentsService: + """Service for building agent descriptors from a team configuration. + + Responsibilities (skeleton): + - Receive a TeamService instance on construction (can be used for validation + or lookups when needed). 
+ - Expose a method that accepts a TeamConfiguration (or raw dict) and + returns a list of agent descriptors. Descriptors are plain dicts that + contain the fields required to later instantiate runtime agents. + + The concrete instantiation logic (semantic kernel / foundry) is intentionally + left out and should be implemented by the user later (see + `instantiate_agents` placeholder). + """ + + def __init__(self, team_service: TeamService): + self.team_service = team_service + self.logger = logging.getLogger(__name__) + + async def get_agents_from_team_config( + self, team_config: Union[TeamConfiguration, Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Return a list of lightweight agent descriptors derived from a + TeamConfiguration or a raw dict. + + Each descriptor contains the basic fields from the team config and a + placeholder where a future runtime/agent object can be attached. + + Args: + team_config: TeamConfiguration model instance or a raw dict + + Returns: + List[dict] -- each dict contains keys like: + - input_key, type, name, system_message, description, icon, + index_name, agent_obj (placeholder) + """ + if not team_config: + return [] + + # Accept either the pydantic TeamConfiguration or a raw dictionary + if hasattr(team_config, "agents"): + agents_raw = team_config.agents or [] + elif isinstance(team_config, dict): + agents_raw = team_config.get("agents", []) + else: + # Unknown type; try to coerce to a list + try: + agents_raw = list(team_config) + except Exception: + agents_raw = [] + + descriptors: List[Dict[str, Any]] = [] + for a in agents_raw: + if isinstance(a, TeamAgent): + desc = { + "input_key": a.input_key, + "type": a.type, + "name": a.name, + "system_message": getattr(a, "system_message", ""), + "description": getattr(a, "description", ""), + "icon": getattr(a, "icon", ""), + "index_name": getattr(a, "index_name", ""), + "use_rag": getattr(a, "use_rag", False), + "use_mcp": getattr(a, "use_mcp", False), + "coding_tools": getattr(a, "coding_tools", False), + # Placeholder for later wiring to a runtime/agent instance + "agent_obj": None, + } + elif isinstance(a, dict): + desc = { + "input_key": a.get("input_key"), + "type": a.get("type"), + "name": a.get("name"), + "system_message": a.get("system_message") or a.get("instructions"), + "description": a.get("description"), + "icon": a.get("icon"), + "index_name": a.get("index_name"), + "use_rag": a.get("use_rag", False), + "use_mcp": a.get("use_mcp", False), + "coding_tools": a.get("coding_tools", False), + "agent_obj": None, + } + else: + # Fallback: keep raw object for later introspection + desc = {"raw": a, "agent_obj": None} + + descriptors.append(desc) + + return descriptors + + async def instantiate_agents(self, agent_descriptors: List[Dict[str, Any]]): + """Placeholder for instantiating runtime agent objects from descriptors. + + The real implementation should create Semantic Kernel / Foundry agents + and attach them to each descriptor under the key `agent_obj` or return a + list of instantiated agents. + + Raises: + NotImplementedError -- this is only a skeleton. 
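+
+        A future implementation might attach one runtime agent per descriptor,
+        e.g. (where build_runtime_agent is a hypothetical helper):
+
+            for desc in agent_descriptors:
+                desc["agent_obj"] = build_runtime_agent(desc)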
+ """ + raise NotImplementedError( + "Agent instantiation is not implemented in the skeleton" + ) diff --git a/src/backend/af/common/services/base_api_service.py b/src/backend/af/common/services/base_api_service.py new file mode 100644 index 000000000..8f8b48ef1 --- /dev/null +++ b/src/backend/af/common/services/base_api_service.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, Optional, Union + +import aiohttp +from common.config.app_config import config + + +class BaseAPIService: + """Minimal async HTTP API service. + + - Reads base endpoints from AppConfig using `from_config` factory. + - Provides simple GET/POST helpers with JSON payloads. + - Designed to be subclassed (e.g., MCPService, FoundryService). + """ + + def __init__( + self, + base_url: str, + *, + default_headers: Optional[Dict[str, str]] = None, + timeout_seconds: int = 30, + session: Optional[aiohttp.ClientSession] = None, + ) -> None: + if not base_url: + raise ValueError("base_url is required") + self.base_url = base_url.rstrip("/") + self.default_headers = default_headers or {} + self.timeout = aiohttp.ClientTimeout(total=timeout_seconds) + self._session_external = session is not None + self._session: Optional[aiohttp.ClientSession] = session + + @classmethod + def from_config( + cls, + endpoint_attr: str, + *, + default: Optional[str] = None, + **kwargs: Any, + ) -> "BaseAPIService": + """Create a service using an endpoint attribute from AppConfig. + + Args: + endpoint_attr: Name of the attribute on AppConfig (e.g., 'AZURE_AI_AGENT_ENDPOINT'). + default: Optional default if attribute missing or empty. + **kwargs: Passed through to the constructor. + """ + base_url = getattr(config, endpoint_attr, None) or default + if not base_url: + raise ValueError( + f"Endpoint '{endpoint_attr}' not configured in AppConfig and no default provided" + ) + return cls(base_url, **kwargs) + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(timeout=self.timeout) + return self._session + + def _url(self, path: str) -> str: + path = path or "" + if not path: + return self.base_url + return f"{self.base_url}/{path.lstrip('/')}" + + async def _request( + self, + method: str, + path: str = "", + *, + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, int, float]]] = None, + json: Optional[Dict[str, Any]] = None, + ) -> aiohttp.ClientResponse: + session = await self._ensure_session() + url = self._url(path) + merged_headers = {**self.default_headers, **(headers or {})} + return await session.request( + method.upper(), url, headers=merged_headers, params=params, json=json + ) + + async def get_json( + self, + path: str = "", + *, + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, int, float]]] = None, + ) -> Any: + resp = await self._request("GET", path, headers=headers, params=params) + resp.raise_for_status() + return await resp.json() + + async def post_json( + self, + path: str = "", + *, + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, int, float]]] = None, + json: Optional[Dict[str, Any]] = None, + ) -> Any: + resp = await self._request( + "POST", path, headers=headers, params=params, json=json + ) + resp.raise_for_status() + return await resp.json() + + async def close(self) -> None: + if self._session and not self._session.closed and not self._session_external: + await self._session.close() + + async def __aenter__(self) -> 
"BaseAPIService": + await self._ensure_session() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() diff --git a/src/backend/af/common/services/foundry_service.py b/src/backend/af/common/services/foundry_service.py new file mode 100644 index 000000000..563f5c56c --- /dev/null +++ b/src/backend/af/common/services/foundry_service.py @@ -0,0 +1,116 @@ +import logging +import re +from typing import Any, Dict, List + +# from git import List +import aiohttp +from azure.ai.projects.aio import AIProjectClient +from common.config.app_config import config + + +class FoundryService: + """Helper around Azure AI Foundry's AIProjectClient. + + Uses AppConfig.get_ai_project_client() to obtain a properly configured + asynchronous client. Provides a small set of convenience methods and + can be extended for specific project operations. + """ + + def __init__(self, client: AIProjectClient | None = None) -> None: + self._client = client + self.logger = logging.getLogger(__name__) + # Model validation configuration + self.subscription_id = config.AZURE_AI_SUBSCRIPTION_ID + self.resource_group = config.AZURE_AI_RESOURCE_GROUP + self.project_name = config.AZURE_AI_PROJECT_NAME + self.project_endpoint = config.AZURE_AI_PROJECT_ENDPOINT + + async def get_client(self) -> AIProjectClient: + if self._client is None: + self._client = config.get_ai_project_client() + return self._client + + # Example convenience wrappers – adjust as your project needs evolve + async def list_connections(self) -> list[Dict[str, Any]]: + client = await self.get_client() + conns = await client.connections.list() + return [c.as_dict() if hasattr(c, "as_dict") else dict(c) for c in conns] + + async def get_connection(self, name: str) -> Dict[str, Any]: + client = await self.get_client() + conn = await client.connections.get(name=name) + return conn.as_dict() if hasattr(conn, "as_dict") else dict(conn) + + # ----------------------- + # Model validation methods + # ----------------------- + async def list_model_deployments(self) -> List[Dict[str, Any]]: + """ + List all model deployments in the Azure AI project using the REST API. 
+ """ + if not all([self.subscription_id, self.resource_group, self.project_name]): + self.logger.error("Azure AI project configuration is incomplete") + return [] + + try: + # Get Azure Management API token (not Cognitive Services token) + credential = config.get_azure_credentials() + token = credential.get_token(config.AZURE_MANAGEMENT_SCOPE) + + # Extract Azure OpenAI resource name from endpoint URL + openai_endpoint = config.AZURE_OPENAI_ENDPOINT + # Extract resource name from URL like "https://aisa-macae-d3x6aoi7uldi.openai.azure.com/" + match = re.search(r"https://([^.]+)\.openai\.azure\.com", openai_endpoint) + if not match: + self.logger.error( + f"Could not extract resource name from endpoint: {openai_endpoint}" + ) + return [] + + openai_resource_name = match.group(1) + self.logger.info(f"Using Azure OpenAI resource: {openai_resource_name}") + + # Query Azure OpenAI resource deployments + url = ( + f"https://management.azure.com/subscriptions/{self.subscription_id}/" + f"resourceGroups/{self.resource_group}/providers/Microsoft.CognitiveServices/" + f"accounts/{openai_resource_name}/deployments" + ) + + headers = { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json", + } + params = {"api-version": "2024-10-01"} + + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers, params=params) as response: + if response.status == 200: + data = await response.json() + deployments = data.get("value", []) + deployment_info: List[Dict[str, Any]] = [] + for deployment in deployments: + deployment_info.append( + { + "name": deployment.get("name"), + "model": deployment.get("properties", {}).get( + "model", {} + ), + "status": deployment.get("properties", {}).get( + "provisioningState" + ), + "endpoint_uri": deployment.get( + "properties", {} + ).get("scoringUri"), + } + ) + return deployment_info + else: + error_text = await response.text() + self.logger.error( + f"Failed to list deployments. Status: {response.status}, Error: {error_text}" + ) + return [] + except Exception as e: + self.logger.error(f"Error listing model deployments: {e}") + return [] diff --git a/src/backend/af/common/services/mcp_service.py b/src/backend/af/common/services/mcp_service.py new file mode 100644 index 000000000..5bdf323cc --- /dev/null +++ b/src/backend/af/common/services/mcp_service.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, Optional + +from common.config.app_config import config + +from .base_api_service import BaseAPIService + + +class MCPService(BaseAPIService): + """Service for interacting with an MCP server. + + Base URL is taken from AppConfig.MCP_SERVER_ENDPOINT if present, + otherwise falls back to v3 MCP default in settings or localhost. + """ + + def __init__(self, base_url: str, *, token: Optional[str] = None, **kwargs): + headers = {"Content-Type": "application/json"} + if token: + headers["Authorization"] = f"Bearer {token}" + super().__init__(base_url, default_headers=headers, **kwargs) + + @classmethod + def from_app_config(cls, **kwargs) -> "MCPService": + # Prefer explicit MCP endpoint if defined; otherwise use the v3 settings default. 
+        endpoint = config.MCP_SERVER_ENDPOINT
+        if not endpoint:
+            # Fail loudly instead of returning None, which would violate the
+            # declared "MCPService" return type and surface later as an
+            # obscure AttributeError.
+            raise ValueError("MCP_SERVER_ENDPOINT is not configured in AppConfig")
+        token = None  # add token retrieval if you enable auth later
+        return cls(endpoint, token=token, **kwargs)
+
+    async def health(self) -> Dict[str, Any]:
+        return await self.get_json("health")
+
+    async def invoke_tool(
+        self, tool_name: str, payload: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        return await self.post_json(f"tools/{tool_name}", json=payload)
diff --git a/src/backend/af/common/services/plan_service.py b/src/backend/af/common/services/plan_service.py
new file mode 100644
index 000000000..ff7d5b30e
--- /dev/null
+++ b/src/backend/af/common/services/plan_service.py
@@ -0,0 +1,254 @@
+import json
+import logging
+from dataclasses import asdict
+
+import v3.models.messages as messages
+from common.database.database_factory import DatabaseFactory
+from common.models.messages_kernel import (
+    AgentMessageData,
+    AgentMessageType,
+    AgentType,
+    PlanStatus,
+)
+from common.utils.event_utils import track_event_if_configured
+from v3.config.settings import orchestration_config
+
+logger = logging.getLogger(__name__)
+
+
+def build_agent_message_from_user_clarification(
+    human_feedback: messages.UserClarificationResponse, user_id: str
+) -> AgentMessageData:
+    """
+    Convert a UserClarificationResponse (human feedback) into an AgentMessageData.
+    """
+    # NOTE: AgentMessageType enum currently defines values with trailing commas in messages_kernel.py.
+    # e.g. HUMAN_AGENT = "Human_Agent",  -> value becomes ('Human_Agent',)
+    # Consider fixing that enum (remove trailing commas) so .value is a string.
+    return AgentMessageData(
+        plan_id=human_feedback.plan_id or "",
+        user_id=user_id,
+        m_plan_id=human_feedback.m_plan_id or None,
+        agent=AgentType.HUMAN.value,  # or simply "Human_Agent"
+        agent_type=AgentMessageType.HUMAN_AGENT,  # will serialize per current enum definition
+        content=human_feedback.answer or "",
+        raw_data=json.dumps(asdict(human_feedback)),
+        steps=[],  # intentionally empty
+        next_steps=[],  # intentionally empty
+    )
+
+
+def build_agent_message_from_agent_message_response(
+    agent_response: messages.AgentMessageResponse,
+    user_id: str,
+) -> AgentMessageData:
+    """
+    Convert a messages.AgentMessageResponse into common.models.messages_kernel.AgentMessageData.
+    This is defensive: it tolerates missing fields and mixed payload shapes.
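+
+    Example (illustrative; fields that are absent on the response fall back to
+    sensible defaults):
+
+        msg = build_agent_message_from_agent_message_response(response, user_id="u-1")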
+ """ + # Robust timestamp parsing (accepts seconds or ms or missing) + + # Raw data serialization + raw = getattr(agent_response, "raw_data", None) + try: + if raw is None: + # try asdict if it's a dataclass-like + try: + raw_str = json.dumps(asdict(agent_response)) + except Exception: + raw_str = json.dumps( + { + k: getattr(agent_response, k) + for k in dir(agent_response) + if not k.startswith("_") + } + ) + elif isinstance(raw, (dict, list)): + raw_str = json.dumps(raw) + else: + raw_str = str(raw) + except Exception: + raw_str = json.dumps({"raw": str(raw)}) + + # Steps / next_steps defaulting + steps = getattr(agent_response, "steps", []) or [] + next_steps = getattr(agent_response, "next_steps", []) or [] + + # Agent name and type + agent_name = ( + getattr(agent_response, "agent", "") + or getattr(agent_response, "agent_name", "") + or getattr(agent_response, "source", "") + ) + # Try to infer agent_type, fallback to AI_AGENT + agent_type_raw = getattr(agent_response, "agent_type", None) + if isinstance(agent_type_raw, AgentMessageType): + agent_type = agent_type_raw + else: + # Normalize common strings + agent_type_str = str(agent_type_raw or "").lower() + if "human" in agent_type_str: + agent_type = AgentMessageType.HUMAN_AGENT + else: + agent_type = AgentMessageType.AI_AGENT + + # Content + content = ( + getattr(agent_response, "content", "") + or getattr(agent_response, "text", "") + or "" + ) + + # plan_id / user_id fallback + plan_id_val = getattr(agent_response, "plan_id", "") or "" + user_id_val = getattr(agent_response, "user_id", "") or user_id + + return AgentMessageData( + plan_id=plan_id_val, + user_id=user_id_val, + m_plan_id=getattr(agent_response, "m_plan_id", ""), + agent=agent_name, + agent_type=agent_type, + content=content, + raw_data=raw_str, + steps=list(steps), + next_steps=list(next_steps), + ) + + +class PlanService: + + @staticmethod + async def handle_plan_approval( + human_feedback: messages.PlanApprovalResponse, user_id: str + ) -> bool: + """ + Process a PlanApprovalResponse coming from the client. 
+
+        Args:
+            human_feedback: messages.PlanApprovalResponse (contains m_plan_id, plan_id, approved, feedback)
+            user_id: authenticated user id
+
+        Returns:
+            True if the approval/rejection was processed, False otherwise
+        """
+        if orchestration_config is None:
+            return False
+        try:
+            mplan = orchestration_config.plans[human_feedback.m_plan_id]
+            memory_store = await DatabaseFactory.get_database(user_id=user_id)
+            if hasattr(mplan, "plan_id"):
+                print(
+                    "Updated orchestration config:",
+                    orchestration_config.plans[human_feedback.m_plan_id],
+                )
+                if human_feedback.approved:
+                    plan = await memory_store.get_plan(human_feedback.plan_id)
+                    if plan:
+                        # Link the stored plan to the in-memory m_plan before persisting
+                        mplan.plan_id = human_feedback.plan_id
+                        mplan.team_id = plan.team_id  # keep team ids consistent
+                        orchestration_config.plans[human_feedback.m_plan_id] = mplan
+                        plan.overall_status = PlanStatus.approved
+                        plan.m_plan = mplan.model_dump()
+                        await memory_store.update_plan(plan)
+                        track_event_if_configured(
+                            "PlanApproved",
+                            {
+                                "m_plan_id": human_feedback.m_plan_id,
+                                "plan_id": human_feedback.plan_id,
+                                "user_id": user_id,
+                            },
+                        )
+                    else:
+                        print("Plan not found in memory store.")
+                        return False
+                else:  # reject plan
+                    track_event_if_configured(
+                        "PlanRejected",
+                        {
+                            "m_plan_id": human_feedback.m_plan_id,
+                            "plan_id": human_feedback.plan_id,
+                            "user_id": user_id,
+                        },
+                    )
+                    await memory_store.delete_plan_by_plan_id(human_feedback.plan_id)
+
+        except Exception as e:
+            print(f"Error processing plan approval: {e}")
+            return False
+        return True
+
+    @staticmethod
+    async def handle_agent_messages(
+        agent_message: messages.AgentMessageResponse, user_id: str
+    ) -> bool:
+        """
+        Process an AgentMessageResponse coming from the client.
+
+        Args:
+            agent_message: messages.AgentMessageResponse (contains relevant message data)
+            user_id: authenticated user id
+
+        Returns:
+            True if the message was recorded, False otherwise
+        """
+        try:
+            agent_msg = build_agent_message_from_agent_message_response(
+                agent_message, user_id
+            )
+
+            # Persist the message through the database layer
+            memory_store = await DatabaseFactory.get_database(user_id=user_id)
+            await memory_store.add_agent_message(agent_msg)
+            if agent_message.is_final:
+                # A final message closes out the plan: stash any remaining
+                # streaming output and mark the plan completed.
+                plan = await memory_store.get_plan(agent_msg.plan_id)
+                plan.streaming_message = agent_message.streaming_message
+                plan.overall_status = PlanStatus.completed
+                await memory_store.update_plan(plan)
+            return True
+        except Exception as e:
+            logger.exception("Failed to handle agent message: %s", e)
+            return False
+
+    @staticmethod
+    async def handle_human_clarification(
+        human_feedback: messages.UserClarificationResponse, user_id: str
+    ) -> bool:
+        """
+        Process a UserClarificationResponse coming from the client.
+
+        Args:
+            human_feedback: messages.UserClarificationResponse (contains relevant message data)
+            user_id: authenticated user id
+
+        Returns:
+            True if the clarification was recorded, False otherwise
+        """
+        try:
+            agent_msg = build_agent_message_from_user_clarification(
+                human_feedback, user_id
+            )
+
+            # Persist if your database layer supports it.
+ # Look for or implement something like: memory_store.add_agent_message(agent_msg) + memory_store = await DatabaseFactory.get_database(user_id=user_id) + await memory_store.add_agent_message(agent_msg) + + return True + except Exception as e: + logger.exception( + "Failed to handle human clarification -> agent message: %s", e + ) + return False diff --git a/src/backend/af/common/services/team_service.py b/src/backend/af/common/services/team_service.py new file mode 100644 index 000000000..02b9cdc2a --- /dev/null +++ b/src/backend/af/common/services/team_service.py @@ -0,0 +1,581 @@ +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple + +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceNotFoundError, +) +from azure.search.documents.indexes import SearchIndexClient +from common.config.app_config import config +from common.database.database_base import DatabaseBase +from common.models.messages_kernel import ( + StartingTask, + TeamAgent, + TeamConfiguration, + UserCurrentTeam, +) +from v3.common.services.foundry_service import FoundryService + + +class TeamService: + """Service for handling JSON team configuration operations.""" + + def __init__(self, memory_context: Optional[DatabaseBase] = None): + """Initialize with optional memory context.""" + self.memory_context = memory_context + self.logger = logging.getLogger(__name__) + + # Search validation configuration + self.search_endpoint = config.AZURE_SEARCH_ENDPOINT + + self.search_credential = config.get_azure_credentials() + + async def validate_and_parse_team_config( + self, json_data: Dict[str, Any], user_id: str + ) -> TeamConfiguration: + """ + Validate and parse team configuration JSON. 
+ + Args: + json_data: Raw JSON data + user_id: User ID who uploaded the configuration + + Returns: + TeamConfiguration object + + Raises: + ValueError: If JSON structure is invalid + """ + try: + # Validate required top-level fields (id and team_id will be generated) + required_fields = [ + "name", + "status", + ] + for field in required_fields: + if field not in json_data: + raise ValueError(f"Missing required field: {field}") + + # Generate unique IDs and timestamps + unique_team_id = str(uuid.uuid4()) + session_id = str(uuid.uuid4()) + current_timestamp = datetime.now(timezone.utc).isoformat() + + # Validate agents array exists and is not empty + if "agents" not in json_data or not isinstance(json_data["agents"], list): + raise ValueError( + "Missing or invalid 'agents' field - must be a non-empty array" + ) + + if len(json_data["agents"]) == 0: + raise ValueError("Agents array cannot be empty") + + # Validate starting_tasks array exists and is not empty + if "starting_tasks" not in json_data or not isinstance( + json_data["starting_tasks"], list + ): + raise ValueError( + "Missing or invalid 'starting_tasks' field - must be a non-empty array" + ) + + if len(json_data["starting_tasks"]) == 0: + raise ValueError("Starting tasks array cannot be empty") + + # Parse agents + agents = [] + for agent_data in json_data["agents"]: + agent = self._validate_and_parse_agent(agent_data) + agents.append(agent) + + # Parse starting tasks + starting_tasks = [] + for task_data in json_data["starting_tasks"]: + task = self._validate_and_parse_task(task_data) + starting_tasks.append(task) + + # Create team configuration + team_config = TeamConfiguration( + id=unique_team_id, # Use generated GUID + session_id=session_id, + team_id=unique_team_id, # Use generated GUID + name=json_data["name"], + status=json_data["status"], + created=current_timestamp, # Use generated timestamp + created_by=user_id, # Use user_id who uploaded the config + agents=agents, + description=json_data.get("description", ""), + logo=json_data.get("logo", ""), + plan=json_data.get("plan", ""), + starting_tasks=starting_tasks, + user_id=user_id, + ) + + self.logger.info( + "Successfully validated team configuration: %s (ID: %s)", + team_config.team_id, + team_config.id, + ) + return team_config + + except Exception as e: + self.logger.error("Error validating team configuration: %s", str(e)) + raise ValueError(f"Invalid team configuration: {str(e)}") from e + + def _validate_and_parse_agent(self, agent_data: Dict[str, Any]) -> TeamAgent: + """Validate and parse a single agent.""" + required_fields = ["input_key", "type", "name", "icon"] + for field in required_fields: + if field not in agent_data: + raise ValueError(f"Agent missing required field: {field}") + + return TeamAgent( + input_key=agent_data["input_key"], + type=agent_data["type"], + name=agent_data["name"], + deployment_name=agent_data.get("deployment_name", ""), + icon=agent_data["icon"], + system_message=agent_data.get("system_message", ""), + description=agent_data.get("description", ""), + use_rag=agent_data.get("use_rag", False), + use_mcp=agent_data.get("use_mcp", False), + use_bing=agent_data.get("use_bing", False), + use_reasoning=agent_data.get("use_reasoning", False), + index_name=agent_data.get("index_name", ""), + coding_tools=agent_data.get("coding_tools", False), + ) + + def _validate_and_parse_task(self, task_data: Dict[str, Any]) -> StartingTask: + """Validate and parse a single starting task.""" + required_fields = ["id", "name", "prompt", "created", 
"creator", "logo"] + for field in required_fields: + if field not in task_data: + raise ValueError(f"Starting task missing required field: {field}") + + return StartingTask( + id=task_data["id"], + name=task_data["name"], + prompt=task_data["prompt"], + created=task_data["created"], + creator=task_data["creator"], + logo=task_data["logo"], + ) + + async def save_team_configuration(self, team_config: TeamConfiguration) -> str: + """ + Save team configuration to the database. + + Args: + team_config: TeamConfiguration object to save + + Returns: + The unique ID of the saved configuration + """ + try: + # Use the specific add_team method from cosmos memory context + await self.memory_context.add_team(team_config) + + self.logger.info( + "Successfully saved team configuration with ID: %s", team_config.id + ) + return team_config.id + + except Exception as e: + self.logger.error("Error saving team configuration: %s", str(e)) + raise ValueError(f"Failed to save team configuration: {str(e)}") from e + + async def get_team_configuration( + self, team_id: str, user_id: str + ) -> Optional[TeamConfiguration]: + """ + Retrieve a team configuration by ID. + + Args: + team_id: Configuration ID to retrieve + user_id: User ID for access control + + Returns: + TeamConfiguration object or None if not found + """ + try: + # Get the specific configuration using the team-specific method + team_config = await self.memory_context.get_team(team_id) + + if team_config is None: + return None + + # Verify the configuration belongs to the user + # if team_config.user_id != user_id: + # self.logger.warning( + # "Access denied: config %s does not belong to user %s", + # team_id, + # user_id, + # ) + # return None + + return team_config + + except (KeyError, TypeError, ValueError) as e: + self.logger.error("Error retrieving team configuration: %s", str(e)) + return None + + async def delete_user_current_team(self, user_id: str) -> bool: + """ + Delete the current team for a user. + + Args: + user_id: User ID to delete the current team for + + Returns: + True if successful, False otherwise + """ + try: + await self.memory_context.delete_current_team(user_id) + self.logger.info("Successfully deleted current team for user %s", user_id) + return True + + except Exception as e: + self.logger.error("Error deleting current team: %s", str(e)) + return False + + async def handle_team_selection( + self, user_id: str, team_id: str + ) -> UserCurrentTeam: + """ + Set a default team for a user. + + Args: + user_id: User ID to set the default team for + team_id: Team ID to set as default + + Returns: + True if successful, False otherwise + """ + print("Handling team selection for user:", user_id, "team:", team_id) + try: + await self.memory_context.delete_current_team(user_id) + current_team = UserCurrentTeam( + user_id=user_id, + team_id=team_id, + ) + await self.memory_context.set_current_team(current_team) + return current_team + + except Exception as e: + self.logger.error("Error setting default team: %s", str(e)) + return None + + async def get_all_team_configurations(self) -> List[TeamConfiguration]: + """ + Retrieve all team configurations for a user. 
+ + Args: + user_id: User ID to retrieve configurations for + + Returns: + List of TeamConfiguration objects + """ + try: + # Use the specific get_all_teams method + team_configs = await self.memory_context.get_all_teams() + return team_configs + + except (KeyError, TypeError, ValueError) as e: + self.logger.error("Error retrieving team configurations: %s", str(e)) + return [] + + async def delete_team_configuration(self, team_id: str, user_id: str) -> bool: + """ + Delete a team configuration by ID. + + Args: + team_id: Configuration ID to delete + user_id: User ID for access control + + Returns: + True if deleted successfully, False if not found + """ + try: + # First, verify the configuration exists and belongs to the user + success = await self.memory_context.delete_team(team_id) + if success: + self.logger.info("Successfully deleted team configuration: %s", team_id) + + return success + + except (KeyError, TypeError, ValueError) as e: + self.logger.error("Error deleting team configuration: %s", str(e)) + return False + + def extract_models_from_agent(self, agent: Dict[str, Any]) -> set: + """ + Extract all possible model references from a single agent configuration. + Skip proxy agents as they don't require deployment models. + """ + models = set() + + # Skip proxy agents - they don't need deployment models + if agent.get("name", "").lower() == "proxyagent": + return models + + if agent.get("deployment_name"): + models.add(str(agent["deployment_name"]).lower()) + + if agent.get("model"): + models.add(str(agent["model"]).lower()) + + config = agent.get("config", {}) + if isinstance(config, dict): + for field in ["model", "deployment_name", "engine"]: + if config.get(field): + models.add(str(config[field]).lower()) + + instructions = agent.get("instructions", "") or agent.get("system_message", "") + if instructions: + models.update(self.extract_models_from_text(str(instructions))) + + return models + + def extract_models_from_text(self, text: str) -> set: + """Extract model names from text using pattern matching.""" + import re + + models = set() + text_lower = text.lower() + model_patterns = [ + r"gpt-4o(?:-\w+)?", + r"gpt-4(?:-\w+)?", + r"gpt-35-turbo(?:-\w+)?", + r"gpt-3\.5-turbo(?:-\w+)?", + r"claude-3(?:-\w+)?", + r"claude-2(?:-\w+)?", + r"gemini-pro(?:-\w+)?", + r"mistral-\w+", + r"llama-?\d+(?:-\w+)?", + r"text-davinci-\d+", + r"text-embedding-\w+", + r"ada-\d+", + r"babbage-\d+", + r"curie-\d+", + r"davinci-\d+", + ] + + for pattern in model_patterns: + matches = re.findall(pattern, text_lower) + models.update(matches) + + return models + + async def validate_team_models( + self, team_config: Dict[str, Any] + ) -> Tuple[bool, List[str]]: + """Validate that all models required by agents in the team config are deployed.""" + try: + foundry_service = FoundryService() + deployments = await foundry_service.list_model_deployments() + available_models = [ + d.get("name", "").lower() + for d in deployments + if d.get("status") == "Succeeded" + ] + + required_models: set = set() + agents = team_config.get("agents", []) + for agent in agents: + if isinstance(agent, dict): + required_models.update(self.extract_models_from_agent(agent)) + + team_level_models = self.extract_team_level_models(team_config) + required_models.update(team_level_models) + + if not required_models: + default_model = config.AZURE_OPENAI_DEPLOYMENT_NAME + required_models.add(default_model.lower()) + + missing_models: List[str] = [] + for model in required_models: + # Temporary bypass for known deployed models + if 
model.lower() in ["gpt-4o", "o3", "gpt-4", "gpt-35-turbo"]: + continue + if model not in available_models: + missing_models.append(model) + + is_valid = len(missing_models) == 0 + if not is_valid: + self.logger.warning(f"Missing model deployments: {missing_models}") + self.logger.info(f"Available deployments: {available_models}") + return is_valid, missing_models + except Exception as e: + self.logger.error(f"Error validating team models: {e}") + return True, [] + + async def get_deployment_status_summary(self) -> Dict[str, Any]: + """Get a summary of deployment status for debugging/monitoring.""" + try: + foundry_service = FoundryService() + deployments = await foundry_service.list_model_deployments() + summary: Dict[str, Any] = { + "total_deployments": len(deployments), + "successful_deployments": [], + "failed_deployments": [], + "pending_deployments": [], + } + for deployment in deployments: + name = deployment.get("name", "unknown") + status = deployment.get("status", "unknown") + if status == "Succeeded": + summary["successful_deployments"].append(name) + elif status in ["Failed", "Canceled"]: + summary["failed_deployments"].append(name) + else: + summary["pending_deployments"].append(name) + return summary + except Exception as e: + self.logger.error(f"Error getting deployment summary: {e}") + return {"error": str(e)} + + def extract_team_level_models(self, team_config: Dict[str, Any]) -> set: + """Extract model references from team-level configuration.""" + models = set() + for field in ["default_model", "model", "llm_model"]: + if team_config.get(field): + models.add(str(team_config[field]).lower()) + settings = team_config.get("settings", {}) + if isinstance(settings, dict): + for field in ["model", "deployment_name"]: + if settings.get(field): + models.add(str(settings[field]).lower()) + env_config = team_config.get("environment", {}) + if isinstance(env_config, dict): + for field in ["model", "openai_deployment"]: + if env_config.get(field): + models.add(str(env_config[field]).lower()) + return models + + # ----------------------- + # Search validation methods + # ----------------------- + + async def validate_team_search_indexes( + self, team_config: Dict[str, Any] + ) -> Tuple[bool, List[str]]: + """ + Validate that all search indexes referenced in the team config exist. + Only validates if there are actually search indexes/RAG agents in the config. 
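+
+        Returns a (is_valid, errors) tuple, e.g. (True, []) when every
+        referenced index resolves, or (False, ["Search index 'contoso-docs'
+        does not exist"]) when one is missing (the index name here is
+        illustrative).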
+ """ + try: + index_names = self.extract_index_names(team_config) + has_rag_agents = self.has_rag_or_search_agents(team_config) + + if not index_names and not has_rag_agents: + self.logger.info( + "No search indexes or RAG agents found in team config - skipping search validation" + ) + return True, [] + + if not self.search_endpoint: + if index_names or has_rag_agents: + error_msg = "Team configuration references search indexes but no Azure Search endpoint is configured" + self.logger.warning(error_msg) + return False, [error_msg] + else: + return True, [] + + if not index_names: + self.logger.info( + "RAG agents found but no specific search indexes specified" + ) + return True, [] + + validation_errors: List[str] = [] + unique_indexes = set(index_names) + self.logger.info( + f"Validating {len(unique_indexes)} search indexes: {list(unique_indexes)}" + ) + for index_name in unique_indexes: + is_valid, error_message = await self.validate_single_index(index_name) + if not is_valid: + validation_errors.append(error_message) + return len(validation_errors) == 0, validation_errors + except Exception as e: + self.logger.error(f"Error validating search indexes: {str(e)}") + return False, [f"Search index validation error: {str(e)}"] + + def extract_index_names(self, team_config: Dict[str, Any]) -> List[str]: + """Extract all index names from RAG agents in the team configuration.""" + index_names: List[str] = [] + agents = team_config.get("agents", []) + for agent in agents: + if isinstance(agent, dict): + agent_type = str(agent.get("type", "")).strip().lower() + if agent_type == "rag": + index_name = agent.get("index_name") + if index_name and str(index_name).strip(): + index_names.append(str(index_name).strip()) + return list(set(index_names)) + + def has_rag_or_search_agents(self, team_config: Dict[str, Any]) -> bool: + """Check if the team configuration contains RAG agents.""" + agents = team_config.get("agents", []) + for agent in agents: + if isinstance(agent, dict): + agent_type = str(agent.get("type", "")).strip().lower() + if agent_type == "rag": + return True + return False + + async def validate_single_index(self, index_name: str) -> Tuple[bool, str]: + """Validate that a single search index exists and is accessible.""" + try: + index_client = SearchIndexClient( + endpoint=self.search_endpoint, credential=self.search_credential + ) + index = index_client.get_index(index_name) + if index: + self.logger.info(f"Search index '{index_name}' found and accessible") + return True, "" + else: + error_msg = f"Search index '{index_name}' exists but may not be properly configured" + self.logger.warning(error_msg) + return False, error_msg + except ResourceNotFoundError: + error_msg = f"Search index '{index_name}' does not exist" + self.logger.error(error_msg) + return False, error_msg + except ClientAuthenticationError as e: + error_msg = ( + f"Authentication failed for search index '{index_name}': {str(e)}" + ) + self.logger.error(error_msg) + return False, error_msg + except HttpResponseError as e: + error_msg = f"Error accessing search index '{index_name}': {str(e)}" + self.logger.error(error_msg) + return False, error_msg + except Exception as e: + error_msg = ( + f"Unexpected error validating search index '{index_name}': {str(e)}" + ) + self.logger.error(error_msg) + return False, error_msg + + async def get_search_index_summary(self) -> Dict[str, Any]: + """Get a summary of available search indexes for debugging/monitoring.""" + try: + if not self.search_endpoint: + return {"error": "No Azure 
Search endpoint configured"} + index_client = SearchIndexClient( + endpoint=self.search_endpoint, credential=self.search_credential + ) + indexes = list(index_client.list_indexes()) + summary = { + "search_endpoint": self.search_endpoint, + "total_indexes": len(indexes), + "available_indexes": [index.name for index in indexes], + } + return summary + except Exception as e: + self.logger.error(f"Error getting search index summary: {e}") + return {"error": str(e)} diff --git a/src/backend/af/config/__init__.py b/src/backend/af/config/__init__.py new file mode 100644 index 000000000..558f942fb --- /dev/null +++ b/src/backend/af/config/__init__.py @@ -0,0 +1 @@ +# Configuration package for Magentic Example diff --git a/src/backend/af/config/agent_registry.py b/src/backend/af/config/agent_registry.py new file mode 100644 index 000000000..564beb136 --- /dev/null +++ b/src/backend/af/config/agent_registry.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Global agent registry for tracking and managing agent lifecycles across the application.""" + +import asyncio +import logging +import threading +from typing import List, Dict, Any, Optional +from weakref import WeakSet + + +class AgentRegistry: + """Global registry for tracking and managing all agent instances across the application.""" + + def __init__(self): + self.logger = logging.getLogger(__name__) + self._lock = threading.Lock() + self._all_agents: WeakSet = WeakSet() + self._agent_metadata: Dict[int, Dict[str, Any]] = {} + + def register_agent(self, agent: Any, user_id: Optional[str] = None) -> None: + """Register an agent instance for tracking and lifecycle management.""" + with self._lock: + try: + self._all_agents.add(agent) + agent_id = id(agent) + self._agent_metadata[agent_id] = { + 'type': type(agent).__name__, + 'user_id': user_id, + 'name': getattr(agent, 'agent_name', getattr(agent, 'name', 'Unknown')) + } + self.logger.info(f"Registered agent: {type(agent).__name__} (ID: {agent_id}, User: {user_id})") + except Exception as e: + self.logger.error(f"Failed to register agent: {e}") + + def unregister_agent(self, agent: Any) -> None: + """Unregister an agent instance.""" + with self._lock: + try: + agent_id = id(agent) + self._all_agents.discard(agent) + if agent_id in self._agent_metadata: + metadata = self._agent_metadata.pop(agent_id) + self.logger.info(f"Unregistered agent: {metadata.get('type', 'Unknown')} (ID: {agent_id})") + except Exception as e: + self.logger.error(f"Failed to unregister agent: {e}") + + def get_all_agents(self) -> List[Any]: + """Get all currently registered agents.""" + with self._lock: + return list(self._all_agents) + + def get_agent_count(self) -> int: + """Get the total number of registered agents.""" + with self._lock: + return len(self._all_agents) + + async def cleanup_all_agents(self) -> None: + """Clean up all registered agents across all users.""" + all_agents = self.get_all_agents() + + if not all_agents: + self.logger.info("No agents to clean up") + return + + self.logger.info(f"🧹 Starting cleanup of {len(all_agents)} total agents") + + # Log agent details for debugging + for i, agent in enumerate(all_agents): + agent_name = getattr(agent, 'agent_name', getattr(agent, 'name', type(agent).__name__)) + agent_type = type(agent).__name__ + has_close = hasattr(agent, 'close') + self.logger.info(f"Agent {i + 1}: {agent_name} (Type: {agent_type}, Has close(): {has_close})") + + # Clean up agents concurrently + cleanup_tasks = [] + for agent in all_agents: + if hasattr(agent, 
'close'): + cleanup_tasks.append(self._safe_close_agent(agent)) + else: + agent_name = getattr(agent, 'agent_name', getattr(agent, 'name', type(agent).__name__)) + self.logger.warning(f"⚠️ Agent {agent_name} has no close() method - just unregistering from registry") + self.unregister_agent(agent) + + if cleanup_tasks: + self.logger.info(f"πŸ”„ Executing {len(cleanup_tasks)} cleanup tasks...") + results = await asyncio.gather(*cleanup_tasks, return_exceptions=True) + + # Log any exceptions that occurred during cleanup + success_count = 0 + for i, result in enumerate(results): + if isinstance(result, Exception): + self.logger.error(f"❌ Error cleaning up agent {i}: {result}") + else: + success_count += 1 + + self.logger.info(f"βœ… Successfully cleaned up {success_count}/{len(cleanup_tasks)} agents") + + # Clear all tracking + with self._lock: + self._all_agents.clear() + self._agent_metadata.clear() + + self.logger.info("πŸŽ‰ Completed cleanup of all agents") + + async def _safe_close_agent(self, agent: Any) -> None: + """Safely close an agent with error handling.""" + try: + agent_name = getattr(agent, 'agent_name', getattr(agent, 'name', type(agent).__name__)) + self.logger.info(f"Closing agent: {agent_name}") + + # Call the agent's close method - it should handle Azure deletion and registry cleanup + if asyncio.iscoroutinefunction(agent.close): + await agent.close() + else: + agent.close() + + self.logger.info(f"Successfully closed agent: {agent_name}") + + except Exception as e: + agent_name = getattr(agent, 'agent_name', getattr(agent, 'name', type(agent).__name__)) + self.logger.error(f"Failed to close agent {agent_name}: {e}") + + def get_registry_status(self) -> Dict[str, Any]: + """Get current status of the agent registry for debugging and monitoring.""" + with self._lock: + status = { + 'total_agents': len(self._all_agents), + 'agent_types': {} + } + + # Count agents by type + for agent in self._all_agents: + agent_type = type(agent).__name__ + status['agent_types'][agent_type] = status['agent_types'].get(agent_type, 0) + 1 + + return status + + +# Global registry instance +agent_registry = AgentRegistry() diff --git a/src/backend/af/config/settings.py b/src/backend/af/config/settings.py new file mode 100644 index 000000000..5958eb4f3 --- /dev/null +++ b/src/backend/af/config/settings.py @@ -0,0 +1,418 @@ +""" +Configuration settings for the Magentic Employee Onboarding system. +Handles Azure OpenAI, MCP, and environment setup. 
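+
+Main pieces (instantiated as module-level singletons at the bottom of this
+file):
+    AzureConfig         - Azure OpenAI endpoint, credential, and chat services
+    MCPConfig           - MCP server endpoint and auth headers
+    OrchestrationConfig - per-user orchestrations plus event-driven waits for
+                          plan approvals and user clarifications
+    ConnectionConfig    - WebSocket connection registry and status messaging
+    TeamConfig          - current team configuration per user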
+""" + +import asyncio +import json +import logging +from typing import Dict, Optional + +from common.config.app_config import config +from common.models.messages_kernel import TeamConfiguration +from fastapi import WebSocket +from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration +from semantic_kernel.connectors.ai.open_ai import ( + AzureChatCompletion, + OpenAIChatPromptExecutionSettings, +) +from v3.models.messages import MPlan, WebsocketMessageType + +logger = logging.getLogger(__name__) + + +class AzureConfig: + """Azure OpenAI and authentication configuration.""" + + def __init__(self): + self.endpoint = config.AZURE_OPENAI_ENDPOINT + self.reasoning_model = config.REASONING_MODEL_NAME + self.standard_model = config.AZURE_OPENAI_DEPLOYMENT_NAME + # self.bing_connection_name = config.AZURE_BING_CONNECTION_NAME + + # Create credential + self.credential = config.get_azure_credentials() + + def ad_token_provider(self) -> str: + token = self.credential.get_token(config.AZURE_COGNITIVE_SERVICES) + return token.token + + async def create_chat_completion_service(self, use_reasoning_model: bool = False): + """Create Azure Chat Completion service.""" + model_name = ( + self.reasoning_model if use_reasoning_model else self.standard_model + ) + # Create Azure Chat Completion service + return AzureChatCompletion( + deployment_name=model_name, + endpoint=self.endpoint, + ad_token_provider=self.ad_token_provider, + ) + + def create_execution_settings(self): + """Create execution settings for OpenAI.""" + return OpenAIChatPromptExecutionSettings(max_tokens=4000, temperature=0.1) + + +class MCPConfig: + """MCP server configuration.""" + + def __init__(self): + self.url = config.MCP_SERVER_ENDPOINT + self.name = config.MCP_SERVER_NAME + self.description = config.MCP_SERVER_DESCRIPTION + + def get_headers(self, token: str): + """Get MCP headers with authentication token.""" + return ( + {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + if token + else {} + ) + + +class OrchestrationConfig: + """Configuration for orchestration settings.""" + + def __init__(self): + self.orchestrations: Dict[str, MagenticOrchestration] = ( + {} + ) # user_id -> orchestration instance + self.plans: Dict[str, MPlan] = {} # plan_id -> plan details + self.approvals: Dict[str, bool] = {} # m_plan_id -> approval status + self.sockets: Dict[str, WebSocket] = {} # user_id -> WebSocket + self.clarifications: Dict[str, str] = {} # m_plan_id -> clarification response + self.max_rounds: int = ( + 20 # Maximum number of replanning rounds 20 needed to accommodate complex tasks + ) + + # Event-driven notification system for approvals and clarifications + self._approval_events: Dict[str, asyncio.Event] = {} + self._clarification_events: Dict[str, asyncio.Event] = {} + + # Default timeout for waiting operations (5 minutes) + self.default_timeout: float = 300.0 + + def get_current_orchestration(self, user_id: str) -> MagenticOrchestration: + """get existing orchestration instance.""" + return self.orchestrations.get(user_id, None) + + def set_approval_pending(self, plan_id: str) -> None: + """Set an approval as pending and create an event for it.""" + self.approvals[plan_id] = None + if plan_id not in self._approval_events: + self._approval_events[plan_id] = asyncio.Event() + else: + # Clear existing event to reset state + self._approval_events[plan_id].clear() + + def set_approval_result(self, plan_id: str, approved: bool) -> None: + """Set the approval result and trigger the event.""" + 
self.approvals[plan_id] = approved + if plan_id in self._approval_events: + self._approval_events[plan_id].set() + + async def wait_for_approval(self, plan_id: str, timeout: Optional[float] = None) -> bool: + """ + Wait for an approval decision with timeout. + + Args: + plan_id: The plan ID to wait for + timeout: Timeout in seconds (defaults to default_timeout) + + Returns: + The approval decision (True/False) + + Raises: + asyncio.TimeoutError: If timeout is exceeded + KeyError: If plan_id is not found in approvals + """ + if timeout is None: + timeout = self.default_timeout + + if plan_id not in self.approvals: + raise KeyError(f"Plan ID {plan_id} not found in approvals") + + if self.approvals[plan_id] is not None: + # Already has a result + return self.approvals[plan_id] + + if plan_id not in self._approval_events: + self._approval_events[plan_id] = asyncio.Event() + + try: + await asyncio.wait_for(self._approval_events[plan_id].wait(), timeout=timeout) + return self.approvals[plan_id] + except asyncio.TimeoutError: + # Clean up on timeout + self.cleanup_approval(plan_id) + raise + except asyncio.CancelledError: + # Handle task cancellation gracefully + logger.debug(f"Approval request {plan_id} was cancelled") + raise + except Exception as e: + # Handle any other unexpected errors + logger.error(f"Unexpected error waiting for approval {plan_id}: {e}") + raise + finally: + # Ensure cleanup happens regardless of how the try block exits + # Only cleanup if the approval is still pending (None) to avoid + # cleaning up successful approvals + if plan_id in self.approvals and self.approvals[plan_id] is None: + self.cleanup_approval(plan_id) + + def set_clarification_pending(self, request_id: str) -> None: + """Set a clarification as pending and create an event for it.""" + self.clarifications[request_id] = None + if request_id not in self._clarification_events: + self._clarification_events[request_id] = asyncio.Event() + else: + # Clear existing event to reset state + self._clarification_events[request_id].clear() + + def set_clarification_result(self, request_id: str, answer: str) -> None: + """Set the clarification response and trigger the event.""" + self.clarifications[request_id] = answer + if request_id in self._clarification_events: + self._clarification_events[request_id].set() + + async def wait_for_clarification(self, request_id: str, timeout: Optional[float] = None) -> str: + """ + Wait for a clarification response with timeout. 
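+
+        Typical flow (a sketch; request_id is whatever ID the caller generated
+        when it sent the clarification request to the client):
+
+            orchestration_config.set_clarification_pending(request_id)
+            # ...client later triggers set_clarification_result(request_id, answer)
+            answer = await orchestration_config.wait_for_clarification(request_id)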
+ + Args: + request_id: The request ID to wait for + timeout: Timeout in seconds (defaults to default_timeout) + + Returns: + The clarification response + + Raises: + asyncio.TimeoutError: If timeout is exceeded + KeyError: If request_id is not found in clarifications + """ + if timeout is None: + timeout = self.default_timeout + + if request_id not in self.clarifications: + raise KeyError(f"Request ID {request_id} not found in clarifications") + + if self.clarifications[request_id] is not None: + # Already has a result + return self.clarifications[request_id] + + if request_id not in self._clarification_events: + self._clarification_events[request_id] = asyncio.Event() + + try: + await asyncio.wait_for(self._clarification_events[request_id].wait(), timeout=timeout) + return self.clarifications[request_id] + except asyncio.TimeoutError: + # Clean up on timeout + self.cleanup_clarification(request_id) + raise + except asyncio.CancelledError: + # Handle task cancellation gracefully + logger.debug(f"Clarification request {request_id} was cancelled") + raise + except Exception as e: + # Handle any other unexpected errors + logger.error(f"Unexpected error waiting for clarification {request_id}: {e}") + raise + finally: + # Ensure cleanup happens regardless of how the try block exits + # Only cleanup if the clarification is still pending (None) to avoid + # cleaning up successful clarifications + if request_id in self.clarifications and self.clarifications[request_id] is None: + self.cleanup_clarification(request_id) + + def cleanup_approval(self, plan_id: str) -> None: + """Clean up approval resources.""" + self.approvals.pop(plan_id, None) + if plan_id in self._approval_events: + del self._approval_events[plan_id] + + def cleanup_clarification(self, request_id: str) -> None: + """Clean up clarification resources.""" + self.clarifications.pop(request_id, None) + if request_id in self._clarification_events: + del self._clarification_events[request_id] + + +class ConnectionConfig: + """Connection manager for WebSocket connections.""" + + def __init__(self): + self.connections: Dict[str, WebSocket] = {} + # Map user_id to process_id for context-based messaging + self.user_to_process: Dict[str, str] = {} + + def add_connection( + self, process_id: str, connection: WebSocket, user_id: str = None + ): + """Add a new connection.""" + # Close existing connection if it exists + if process_id in self.connections: + try: + asyncio.create_task(self.connections[process_id].close()) + except Exception as e: + logger.error( + f"Error closing existing connection for user {process_id}: {e}" + ) + + self.connections[process_id] = connection + # Map user to process for context-based messaging + if user_id: + user_id = str(user_id) + # If this user already has a different process mapped, close that old connection + old_process_id = self.user_to_process.get(user_id) + if old_process_id and old_process_id != process_id: + old_connection = self.connections.get(old_process_id) + if old_connection: + try: + asyncio.create_task(old_connection.close()) + del self.connections[old_process_id] + logger.info( + f"Closed old connection {old_process_id} for user {user_id}" + ) + except Exception as e: + logger.error( + f"Error closing old connection for user {user_id}: {e}" + ) + + self.user_to_process[user_id] = process_id + logger.info( + f"WebSocket connection added for process: {process_id} (user: {user_id})" + ) + else: + logger.info(f"WebSocket connection added for process: {process_id}") + + def remove_connection(self, 
process_id): + """Remove a connection.""" + process_id = str(process_id) + if process_id in self.connections: + del self.connections[process_id] + + # Remove from user mapping if exists + for user_id, mapped_process_id in list(self.user_to_process.items()): + if mapped_process_id == process_id: + del self.user_to_process[user_id] + logger.debug(f"Removed user mapping: {user_id} -> {process_id}") + break + + def get_connection(self, process_id): + """Get a connection.""" + return self.connections.get(process_id) + + async def close_connection(self, process_id): + """Remove a connection.""" + connection = self.get_connection(process_id) + if connection: + try: + await connection.close() + logger.info("Connection closed for batch ID: %s", process_id) + except Exception as e: + logger.error(f"Error closing connection for {process_id}: {e}") + else: + logger.warning("No connection found for batch ID: %s", process_id) + + # Always remove from connections dict + self.remove_connection(process_id) + logger.info("Connection removed for batch ID: %s", process_id) + + async def send_status_update_async( + self, + message: any, + user_id: str, + message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE, + ): + """Send a status update to a specific client.""" + + if not user_id: + logger.warning("No user_id available for WebSocket message") + return + + process_id = self.user_to_process.get(user_id) + if not process_id: + logger.warning("No active WebSocket process found for user ID: %s", user_id) + logger.debug( + f"Available user mappings: {list(self.user_to_process.keys())}" + ) + return + + # Convert message to proper format for frontend + try: + if hasattr(message, "to_dict"): + # Use the custom to_dict method if available + message_data = message.to_dict() + elif hasattr(message, "data") and hasattr(message, "type"): + # Handle structured messages with data property + message_data = message.data + elif isinstance(message, dict): + # Already a dictionary + message_data = message + else: + # Convert to string if it's a simple type + message_data = str(message) + except Exception as e: + logger.error("Error processing message data: %s", e) + message_data = str(message) + + standard_message = {"type": message_type, "data": message_data} + connection = self.get_connection(process_id) + if connection: + try: + str_message = json.dumps(standard_message, default=str) + await connection.send_text(str_message) + logger.debug(f"Message sent to user {user_id} via process {process_id}") + except Exception as e: + logger.error(f"Failed to send message to user {user_id}: {e}") + # Clean up stale connection + self.remove_connection(process_id) + else: + logger.warning( + "No connection found for process ID: %s (user: %s)", process_id, user_id + ) + # Clean up stale mapping + if user_id in self.user_to_process: + del self.user_to_process[user_id] + + def send_status_update(self, message: str, process_id: str): + """Send a status update to a specific client (sync wrapper).""" + process_id = str(process_id) + connection = self.get_connection(process_id) + if connection: + try: + # Use asyncio.create_task instead of run_coroutine_threadsafe + asyncio.create_task(connection.send_text(message)) + except Exception as e: + logger.error(f"Failed to send message to process {process_id}: {e}") + else: + logger.warning("No connection found for process ID: %s", process_id) + + +class TeamConfig: + """Team configuration for agents.""" + + def __init__(self): + self.teams: Dict[str, TeamConfiguration] = {} + + 
def set_current_team(self, user_id: str, team_configuration: TeamConfiguration): + """Add a new team configuration.""" + + # To do: close current team of agents if any + + self.teams[user_id] = team_configuration + + def get_current_team(self, user_id: str) -> TeamConfiguration: + """Get the current team configuration.""" + return self.teams.get(user_id, None) + + +# Global config instances +azure_config = AzureConfig() +mcp_config = MCPConfig() +orchestration_config = OrchestrationConfig() +connection_config = ConnectionConfig() +team_config = TeamConfig() diff --git a/src/backend/af/magentic_agents/foundry_agent.py b/src/backend/af/magentic_agents/foundry_agent.py new file mode 100644 index 000000000..30746e847 --- /dev/null +++ b/src/backend/af/magentic_agents/foundry_agent.py @@ -0,0 +1,294 @@ +"""Agent template for building foundry agents with Azure AI Search, Bing, and MCP plugins (agent_framework version).""" + +import logging +from typing import List, Optional + +from azure.ai.agents.models import Agent, AzureAISearchTool, CodeInterpreterToolDefinition +from agent_framework.azure import AzureAIAgentClient +from agent_framework import ChatMessage, Role, ChatOptions, HostedMCPTool # HostedMCPTool for MCP plugin mapping + +from v3.magentic_agents.common.lifecycle import AzureAgentBase +from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig +from v3.config.agent_registry import agent_registry + +# exception too broad warning +# pylint: disable=w0718 + + +class FoundryAgentTemplate(AzureAgentBase): + """Agent that uses Azure AI Search (RAG) and optional MCP tools via agent_framework.""" + + def __init__( + self, + agent_name: str, + agent_description: str, + agent_instructions: str, + model_deployment_name: str, + enable_code_interpreter: bool = False, + mcp_config: MCPConfig | None = None, + # bing_config: BingConfig | None = None, + search_config: SearchConfig | None = None, + ) -> None: + super().__init__(mcp=mcp_config) + self.agent_name = agent_name + self.agent_description = agent_description + self.agent_instructions = agent_instructions + self.model_deployment_name = model_deployment_name + self.enable_code_interpreter = enable_code_interpreter + # self.bing = bing_config + self.mcp = mcp_config + self.search = search_config + self._search_connection = None + self._bing_connection = None + self.logger = logging.getLogger(__name__) + + if self.model_deployment_name in ["o3", "o4-mini"]: + raise ValueError( + "The current version of Foundry agents does not support reasoning models." 
+ ) + + async def _make_azure_search_tool(self) -> Optional[AzureAISearchTool]: + """Create Azure AI Search tool for RAG capabilities.""" + if not all([self.client, self.search and self.search.connection_name, self.search and self.search.index_name]): + self.logger.info("Azure AI Search tool not enabled") + return None + + try: + self._search_connection = await self.client.connections.get( + name=self.search.connection_name + ) + self.logger.info("Found Azure AI Search connection: %s", self._search_connection.id) + + search_tool = AzureAISearchTool( + index_connection_id=self._search_connection.id, + index_name=self.search.index_name, + ) + self.logger.info("Azure AI Search tool created for index: %s", self.search.index_name) + return search_tool + + except Exception as ex: + self.logger.error( + "Azure AI Search tool creation failed: %s | Connection name: %s | Index name: %s | " + "Ensure the connection exists in Azure AI Foundry portal.", + ex, + getattr(self.search, "connection_name", None), + getattr(self.search, "index_name", None), + ) + return None + + async def _collect_tools_and_resources(self) -> tuple[List, dict]: + """Collect all available tools and tool_resources to embed in persistent agent definition.""" + tools: List = [] + tool_resources: dict = {} + + if self.search and self.search.connection_name and self.search.index_name: + search_tool = await self._make_azure_search_tool() + if search_tool: + tools.extend(search_tool.definitions) + tool_resources = search_tool.resources + self.logger.info( + "Added Azure AI Search tools: %d tool definitions", len(search_tool.definitions) + ) + else: + self.logger.error("Azure AI Search tool not configured properly") + + if self.enable_code_interpreter: + try: + tools.append(CodeInterpreterToolDefinition()) + self.logger.info("Added Code Interpreter tool") + except ImportError as ie: + self.logger.error("Code Interpreter tool requires additional dependencies: %s", ie) + + self.logger.info("Total tools configured in definition: %d", len(tools)) + return tools, tool_resources + + async def _after_open(self) -> None: + """Build or reuse the Azure AI agent definition; create agent_framework client.""" + definition = await self._get_azure_ai_agent_definition(self.agent_name) + + if definition is not None: + connection_compatible = await self._check_connection_compatibility(definition) + if not connection_compatible: + await self.client.agents.delete_agent(definition.id) + self.logger.info( + "Existing agent '%s' used incompatible connection. 
Creating new definition.", + self.agent_name, + ) + definition = None + + if definition is None: + tools, tool_resources = await self._collect_tools_and_resources() + definition = await self.client.agents.create_agent( + model=self.model_deployment_name, + name=self.agent_name, + description=self.agent_description, + instructions=self.agent_instructions, + tools=tools, + tool_resources=tool_resources, + ) + + try: + # Wrap existing agent definition with agent_framework client (persistent agent mode) + self._agent = AzureAIAgentClient( + project_client=self.client, + agent_id=str(definition.id), + agent_name=self.agent_name, + thread_id=None, # created dynamically if omitted during invocation + ) + except Exception as ex: + self.logger.error("Failed to initialize AzureAIAgentClient: %s", ex) + raise + + # Register with global registry + try: + agent_registry.register_agent(self) + self.logger.info("πŸ“ Registered agent '%s' with global registry", self.agent_name) + except Exception as registry_error: + self.logger.warning( + "⚠️ Failed to register agent '%s' with registry: %s", self.agent_name, registry_error + ) + + async def fetch_run_details(self, thread_id: str, run_id: str) -> None: + """Fetch and log run details on failure for diagnostics.""" + try: + run = await self.client.agents.runs.get(thread=thread_id, run=run_id) + self.logger.error( + "Run failure details | status=%s | id=%s | last_error=%s | usage=%s", + getattr(run, "status", None), + run_id, + getattr(run, "last_error", None), + getattr(run, "usage", None), + ) + except Exception as ex: + self.logger.error("Could not fetch run details: %s", ex) + + async def _check_connection_compatibility(self, existing_definition: Agent) -> bool: + """Ensure existing agent definition's Azure AI Search connection matches current configuration.""" + try: + if not self.search or not self.search.connection_name: + self.logger.info("No search configuration provided; treating existing definition as compatible.") + return True + + if not getattr(existing_definition, "tool_resources", None): + self.logger.info("Existing definition lacks tool resources.") + return not self.search.connection_name + + azure_ai_search_resources = existing_definition.tool_resources.get("azure_ai_search", {}) + if not azure_ai_search_resources: + self.logger.info("Existing definition has no Azure AI Search resources.") + return False + + indexes = azure_ai_search_resources.get("indexes", []) + if not indexes: + self.logger.info("Existing definition search resources contain no indexes.") + return False + + existing_connection_id = indexes[0].get("index_connection_id") + if not existing_connection_id: + self.logger.info("Existing definition missing connection ID.") + return False + + current_connection = await self.client.connections.get(name=self.search.connection_name) + current_connection_id = current_connection.id + compatible = existing_connection_id == current_connection_id + + if compatible: + self.logger.info("Connection compatible: %s", existing_connection_id) + else: + self.logger.info( + "Connection mismatch: existing %s vs current %s", + existing_connection_id, + current_connection_id, + ) + return compatible + except Exception as ex: + self.logger.error("Error checking connection compatibility: %s", ex) + return False + + async def _get_azure_ai_agent_definition(self, agent_name: str) -> Agent | None: + """Retrieve an existing Azure AI Agent definition by name if present.""" + try: + agent_id = None + agent_list = self.client.agents.list_agents() + async 
for agent in agent_list: + if agent.name == agent_name: + agent_id = agent.id + break + if agent_id is not None: + self.logger.info("Found existing agent definition with ID %s", agent_id) + return await self.client.agents.get_agent(agent_id) + return None + except Exception as e: + if "ResourceNotFound" in str(e) or "404" in str(e): + self.logger.info("Agent '%s' not found; will create new definition.", agent_name) + else: + self.logger.warning( + "Unexpected error retrieving agent '%s': %s. Proceeding to create new definition.", + agent_name, + e, + ) + return None + + async def invoke(self, prompt: str): + """ + Stream model output for a prompt. + + Yields agent_framework ChatResponseUpdate objects: + - update.text for incremental text + - update.contents for tool calls / usage events + """ + if not hasattr(self, "_agent") or self._agent is None: + raise RuntimeError("Agent client not initialized; call open() first.") + + messages = [ChatMessage(role=Role.USER, text=prompt)] + + tools = [] + # Map MCP plugin (if any) to HostedMCPTool for runtime tool calling + if self.mcp_plugin: + # Minimal HostedMCPTool; advanced mapping (approval modes, categories) can be added later. + tools.append( + HostedMCPTool( + name=self.mcp_plugin.name, + server_label=self.mcp_plugin.name.replace(" ", "_"), + description=getattr(self.mcp_plugin, "description", ""), + ) + ) + + chat_options = ChatOptions( + model_id=self.model_deployment_name, + tools=tools if tools else None, + tool_choice="auto", + allow_multiple_tool_calls=True, + temperature=0.7, + ) + + async for update in self._agent.get_streaming_response( + messages=messages, + chat_options=chat_options, + instructions=self.agent_instructions, + ): + yield update + + +async def create_foundry_agent( + agent_name: str, + agent_description: str, + agent_instructions: str, + model_deployment_name: str, + mcp_config: MCPConfig, + # bing_config: BingConfig, + search_config: SearchConfig, +) -> FoundryAgentTemplate: + """Factory function to create and open a FoundryAgentTemplate (agent_framework version).""" + agent = FoundryAgentTemplate( + agent_name=agent_name, + agent_description=agent_description, + agent_instructions=agent_instructions, + model_deployment_name=model_deployment_name, + enable_code_interpreter=True, + mcp_config=mcp_config, + # bing_config=bing_config, + search_config=search_config, + ) + await agent.open() + return agent \ No newline at end of file diff --git a/src/backend/af/magentic_agents/magentic_agent_factory.py b/src/backend/af/magentic_agents/magentic_agent_factory.py new file mode 100644 index 000000000..ed74e89be --- /dev/null +++ b/src/backend/af/magentic_agents/magentic_agent_factory.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft. All rights reserved. 
+"""Factory for creating and managing magentic agents from JSON configurations.""" + +import json +import logging +from types import SimpleNamespace +from typing import List, Union + +from common.config.app_config import config +from common.models.messages_kernel import TeamConfiguration +from v3.magentic_agents.foundry_agent import FoundryAgentTemplate +from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig + +# from v3.magentic_agents.models.agent_models import (BingConfig, MCPConfig, +# SearchConfig) +from v3.magentic_agents.proxy_agent import ProxyAgent +from v3.magentic_agents.reasoning_agent import ReasoningAgentTemplate + + +class UnsupportedModelError(Exception): + """Raised when an unsupported model is specified.""" + + +class InvalidConfigurationError(Exception): + """Raised when agent configuration is invalid.""" + + +class MagenticAgentFactory: + """Factory for creating and managing magentic agents from JSON configurations.""" + + def __init__(self): + self.logger = logging.getLogger(__name__) + self._agent_list: List = [] + + # @staticmethod + # def parse_team_config(file_path: Union[str, Path]) -> SimpleNamespace: + # """Parse JSON file into objects using SimpleNamespace.""" + # with open(file_path, 'r') as f: + # data = json.load(f) + # return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d)) + + async def create_agent_from_config(self, user_id: str, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]: + """ + Create an agent from configuration object. + + Args: + user_id: User ID + agent_obj: Agent object from parsed JSON (SimpleNamespace) + team_model: Model name to determine which template to use + + Returns: + Configured agent instance + + Raises: + UnsupportedModelError: If model is not supported + InvalidConfigurationError: If configuration is invalid + """ + # Get model from agent config, team model, or environment + deployment_name = getattr(agent_obj, "deployment_name", None) + + if not deployment_name and agent_obj.name.lower() == "proxyagent": + self.logger.info("Creating ProxyAgent") + return ProxyAgent(user_id=user_id) + + # Validate supported models + supported_models = json.loads(config.SUPPORTED_MODELS) + + if deployment_name not in supported_models: + raise UnsupportedModelError( + f"Model '{deployment_name}' not supported. Supported: {supported_models}" + ) + + # Determine which template to use + use_reasoning = deployment_name.startswith("o") + + # Validate reasoning template constraints + if use_reasoning: + if getattr(agent_obj, "use_bing", False) or getattr( + agent_obj, "coding_tools", False + ): + raise InvalidConfigurationError( + f"ReasoningAgentTemplate cannot use Bing search or coding tools. 
" + f"Agent '{agent_obj.name}' has use_bing={getattr(agent_obj, 'use_bing', False)}, " + f"coding_tools={getattr(agent_obj, 'coding_tools', False)}" + ) + + # Only create configs for explicitly requested capabilities + search_config = ( + SearchConfig.from_env() if getattr(agent_obj, "use_rag", False) else None + ) + mcp_config = ( + MCPConfig.from_env() if getattr(agent_obj, "use_mcp", False) else None + ) + # bing_config = BingConfig.from_env() if getattr(agent_obj, 'use_bing', False) else None + + self.logger.info( + f"Creating agent '{agent_obj.name}' with model '{deployment_name}' " + f"(Template: {'Reasoning' if use_reasoning else 'Foundry'})" + ) + + # Create appropriate agent + if use_reasoning: + # Get reasoning specific configuration + azure_openai_endpoint = config.AZURE_OPENAI_ENDPOINT + + agent = ReasoningAgentTemplate( + agent_name=agent_obj.name, + agent_description=getattr(agent_obj, "description", ""), + agent_instructions=getattr(agent_obj, "system_message", ""), + model_deployment_name=deployment_name, + azure_openai_endpoint=azure_openai_endpoint, + search_config=search_config, + mcp_config=mcp_config, + ) + else: + agent = FoundryAgentTemplate( + agent_name=agent_obj.name, + agent_description=getattr(agent_obj, "description", ""), + agent_instructions=getattr(agent_obj, "system_message", ""), + model_deployment_name=deployment_name, + enable_code_interpreter=getattr(agent_obj, "coding_tools", False), + mcp_config=mcp_config, + # bing_config=bing_config, + search_config=search_config, + ) + + await agent.open() + self.logger.info( + f"Successfully created and initialized agent '{agent_obj.name}'" + ) + return agent + + async def get_agents(self, user_id: str, team_config_input: TeamConfiguration) -> List: + """ + Create and return a team of agents from JSON configuration. 
+
+        Args:
+            user_id: User ID
+            team_config_input: Team configuration object from Cosmos DB
+
+        Returns:
+            List of initialized agent instances
+        """
+        try:
+            initialized_agents = []
+
+            for i, agent_cfg in enumerate(team_config_input.agents, 1):
+                try:
+                    self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}")
+
+                    agent = await self.create_agent_from_config(user_id, agent_cfg)
+                    initialized_agents.append(agent)
+                    self._agent_list.append(agent)  # Keep track for cleanup
+
+                    self.logger.info(
+                        f"✅ Agent {i}/{len(team_config_input.agents)} created: {agent_cfg.name}"
+                    )
+
+                except (UnsupportedModelError, InvalidConfigurationError) as e:
+                    self.logger.warning(f"Skipped agent {agent_cfg.name}: {e}")
+                    continue
+                except Exception as e:
+                    self.logger.error(f"Failed to create agent {agent_cfg.name}: {e}")
+                    continue
+
+            self.logger.info(
+                f"Successfully created {len(initialized_agents)}/{len(team_config_input.agents)} agents for team '{team_config_input.name}'"
+            )
+            return initialized_agents
+
+        except Exception as e:
+            self.logger.error(f"Failed to load team configuration: {e}")
+            raise
+
+    @classmethod
+    async def cleanup_all_agents(cls, agent_list: List):
+        """Clean up all created agents."""
+        cls.logger = logging.getLogger(__name__)
+        cls.logger.info(f"Cleaning up {len(agent_list)} agents")
+
+        for agent in agent_list:
+            try:
+                await agent.close()
+            except Exception as ex:
+                name = getattr(
+                    agent,
+                    "agent_name",
+                    getattr(agent, "__class__", type("X", (object,), {})).__name__,
+                )
+                cls.logger.warning(f"Error closing agent {name}: {ex}")
+
+        agent_list.clear()
+        cls.logger.info("Agent cleanup completed")
diff --git a/src/backend/af/magentic_agents/proxy_agent.py b/src/backend/af/magentic_agents/proxy_agent.py
new file mode 100644
index 000000000..02cd90b79
--- /dev/null
+++ b/src/backend/af/magentic_agents/proxy_agent.py
@@ -0,0 +1,373 @@
+# Copyright (c) Microsoft. All rights reserved.
+"""Proxy agent that prompts for human clarification.""" + +import asyncio +import logging +import time +import uuid +from collections.abc import AsyncIterable +from typing import AsyncIterator, Optional + +from pydantic import Field +from semantic_kernel.agents import ( # pylint: disable=no-name-in-module + AgentResponseItem, + AgentThread, +) +from semantic_kernel.agents.agent import Agent +from semantic_kernel.contents import ( + AuthorRole, + ChatMessageContent, + StreamingChatMessageContent, +) +from semantic_kernel.contents.chat_history import ChatHistory +from semantic_kernel.contents.history_reducer.chat_history_reducer import ( + ChatHistoryReducer, +) +from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException +from typing_extensions import override +from v3.callbacks.response_handlers import (agent_response_callback, + streaming_agent_response_callback) +from v3.config.settings import connection_config, orchestration_config +from v3.models.messages import (UserClarificationRequest, + UserClarificationResponse, WebsocketMessageType) + +# Initialize logger for the module +logger = logging.getLogger(__name__) + + +class DummyAgentThread(AgentThread): + """Dummy thread implementation for proxy agent.""" + + def __init__( + self, chat_history: ChatHistory | None = None, thread_id: str | None = None + ): + super().__init__() + self._chat_history = chat_history if chat_history is not None else ChatHistory() + self._id: str = thread_id or f"thread_{uuid.uuid4().hex}" + self._is_deleted = False + self.logger = logging.getLogger(__name__) + + @override + async def _create(self) -> str: + """Starts the thread and returns its ID.""" + return self._id + + @override + async def _delete(self) -> None: + """Ends the current thread.""" + self._chat_history.clear() + + @override + async def _on_new_message(self, new_message: str | ChatMessageContent) -> None: + """Called when a new message has been contributed to the chat.""" + if isinstance(new_message, str): + new_message = ChatMessageContent(role=AuthorRole.USER, content=new_message) + + if ( + not new_message.metadata + or "thread_id" not in new_message.metadata + or new_message.metadata["thread_id"] != self._id + ): + self._chat_history.add_message(new_message) + + async def get_messages(self) -> AsyncIterable[ChatMessageContent]: + """Retrieve the current chat history. + + Returns: + An async iterable of ChatMessageContent. + """ + if self._is_deleted: + raise AgentThreadOperationException( + "Cannot retrieve chat history, since the thread has been deleted." + ) + if self._id is None: + await self.create() + for message in self._chat_history.messages: + yield message + + async def reduce(self) -> ChatHistory | None: + """Reduce the chat history to a smaller size.""" + if self._id is None: + raise AgentThreadOperationException( + "Cannot reduce chat history, since the thread is not currently active." 
+ ) + if not isinstance(self._chat_history, ChatHistoryReducer): + return None + return await self._chat_history.reduce() + + +class ProxyAgentResponseItem: + """Response item wrapper for proxy agent responses.""" + + def __init__(self, message: ChatMessageContent, thread: AgentThread): + self.message = message + self.thread = thread + self.logger = logging.getLogger(__name__) + + +class ProxyAgent(Agent): + """Simple proxy agent that prompts for human clarification.""" + + # Declare as Pydantic field + user_id: str = Field( + default=None, description="User ID for WebSocket messaging" + ) + + def __init__(self, user_id: str, **kwargs): + # Get user_id from parameter, fallback to empty string + effective_user_id = user_id or "" + super().__init__( + name="ProxyAgent", + description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.", + user_id=effective_user_id, + **kwargs, + ) + self.instructions = "" + + def _create_message_content( + self, content: str, thread_id: str = None + ) -> ChatMessageContent: + """Create a ChatMessageContent with proper metadata.""" + return ChatMessageContent( + role=AuthorRole.ASSISTANT, + content=content, + name=self.name, + metadata={"thread_id": thread_id} if thread_id else {}, + ) + + async def _trigger_response_callbacks(self, message_content: ChatMessageContent): + """Manually trigger the same response callbacks used by other agents.""" + # Get current user_id dynamically instead of using stored value + current_user = self.user_id or "" + + # Trigger the standard agent response callback + agent_response_callback(message_content, current_user) + + async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False): + """Manually trigger streaming callbacks for real-time updates.""" + # Get current user_id dynamically instead of using stored value + current_user = self.user_id or "" + streaming_message = StreamingChatMessageContent( + role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0 + ) + await streaming_agent_response_callback( + streaming_message, is_final, current_user + ) + + async def invoke( + self, message: str, *, thread: AgentThread | None = None, **kwargs + ) -> AsyncIterator[ChatMessageContent]: + """Ask human user for clarification about the message.""" + + thread = await self._ensure_thread_exists_with_messages( + messages=message, + thread=thread, + construct_thread=lambda: DummyAgentThread(), + expected_type=DummyAgentThread, + ) + + # Send clarification request via streaming callbacks + clarification_request = f"I need clarification about: {message}" + + clarification_message = UserClarificationRequest( + question=clarification_request, + request_id=str(uuid.uuid4()), # Unique ID for the request + ) + + # Send the approval request to the user's WebSocket + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, + "data": clarification_message, + }, + user_id=self.user_id, + message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, + ) + + # Get human input + human_response = await self._wait_for_user_clarification( + clarification_message.request_id + ) + + # Handle silent timeout/cancellation + if human_response is None: + # Process was terminated silently - don't yield any response + logger.debug("Clarification process terminated silently - ending 
invoke") + return + + # Extract the answer from the response + answer = human_response.answer if human_response else "No additional clarification provided." + + response = f"Human clarification: {answer}" + + chat_message = self._create_message_content(response, thread.id) + + yield AgentResponseItem(message=chat_message, thread=thread) + + async def invoke_stream( + self, messages, thread=None, **kwargs + ) -> AsyncIterator[ProxyAgentResponseItem]: + """Stream version - handles thread management for orchestration.""" + + thread = await self._ensure_thread_exists_with_messages( + messages=messages, + thread=thread, + construct_thread=lambda: DummyAgentThread(), + expected_type=DummyAgentThread, + ) + + # Extract message content + if isinstance(messages, list) and messages: + message = ( + messages[-1].content + if hasattr(messages[-1], "content") + else str(messages[-1]) + ) + elif isinstance(messages, str): + message = messages + else: + message = str(messages) + + # Send clarification request via streaming callbacks + clarification_request = f"I need clarification about: {message}" + + clarification_message = UserClarificationRequest( + question=clarification_request, + request_id=str(uuid.uuid4()), # Unique ID for the request + ) + + # Send the approval request to the user's WebSocket + # The user_id will be automatically retrieved from context + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, + "data": clarification_message, + }, + user_id=self.user_id, + message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, + ) + + # Get human input - replace with websocket call when available + human_response = await self._wait_for_user_clarification( + clarification_message.request_id + ) + + # Handle silent timeout/cancellation + if human_response is None: + # Process was terminated silently - don't yield any response + logger.debug("Clarification process terminated silently - ending invoke_stream") + return + + # Extract the answer from the response + answer = human_response.answer if human_response else "No additional clarification provided." + + response = f"Human clarification: {answer}" + + chat_message = self._create_message_content(response, thread.id) + + yield AgentResponseItem(message=chat_message, thread=thread) + + async def _wait_for_user_clarification( + self, request_id: str + ) -> Optional[UserClarificationResponse]: + """ + Wait for user clarification response using event-driven pattern with timeout handling. 
+ + Args: + request_id: The request ID to wait for clarification + + Returns: + UserClarificationResponse: Clarification result with request ID and answer + + Raises: + asyncio.TimeoutError: If timeout is exceeded (300 seconds default) + """ + # logger.info(f"Waiting for user clarification for request: {request_id}") + + # Initialize clarification as pending using the new event-driven method + orchestration_config.set_clarification_pending(request_id) + + try: + # Wait for clarification with timeout using the new event-driven method + answer = await orchestration_config.wait_for_clarification(request_id) + + # logger.info(f"Clarification received for request {request_id}: {answer}") + return UserClarificationResponse( + request_id=request_id, + answer=answer, + ) + except asyncio.TimeoutError: + # Enhanced timeout handling - notify user via WebSocket and cleanup + logger.debug(f"Clarification timeout for request {request_id} - notifying user and terminating process") + + # Create timeout notification message + from v3.models.messages import TimeoutNotification, WebsocketMessageType + timeout_notification = TimeoutNotification( + timeout_type="clarification", + request_id=request_id, + message=f"User clarification request timed out after {orchestration_config.default_timeout} seconds. Please try again.", + timestamp=time.time(), + timeout_duration=orchestration_config.default_timeout + ) + + # Send timeout notification to user via WebSocket + try: + await connection_config.send_status_update_async( + message=timeout_notification, + user_id=self.user_id, + message_type=WebsocketMessageType.TIMEOUT_NOTIFICATION, + ) + logger.info(f"Timeout notification sent to user {self.user_id} for clarification {request_id}") + except Exception as e: + logger.error(f"Failed to send timeout notification: {e}") + + # Clean up this specific request + orchestration_config.cleanup_clarification(request_id) + + # Return None to indicate silent termination + # The timeout naturally stops this specific wait operation without affecting other tasks + return None + + except KeyError as e: + # Silent error handling for invalid request IDs + logger.debug(f"Request ID not found: {e} - terminating process silently") + return None + + except asyncio.CancelledError: + # Handle task cancellation gracefully + logger.debug(f"Clarification request {request_id} was cancelled") + orchestration_config.cleanup_clarification(request_id) + return None + + except Exception as e: + # Silent error handling for unexpected errors + logger.debug(f"Unexpected error waiting for clarification: {e} - terminating process silently") + orchestration_config.cleanup_clarification(request_id) + return None + finally: + # Ensure cleanup happens for any incomplete requests + # This provides an additional safety net for resource cleanup + if (request_id in orchestration_config.clarifications and orchestration_config.clarifications[request_id] is None): + logger.debug(f"Final cleanup for pending clarification request {request_id}") + orchestration_config.cleanup_clarification(request_id) + + async def get_response(self, chat_history, **kwargs): + """Get response from the agent - required by Agent base class.""" + # Extract the latest user message + latest_message = ( + chat_history.messages[-1].content if chat_history.messages else "" + ) + + # Use our invoke method to get the response + async for response in self.invoke(latest_message, **kwargs): + return response + + # Fallback if no response generated + return ChatMessageContent( + 
role=AuthorRole.ASSISTANT, content="No clarification provided."
+        )
+
+
+async def create_proxy_agent(user_id: str = None):
+    """Factory function for human proxy agent."""
+    return ProxyAgent(user_id=user_id)
diff --git a/src/backend/af/magentic_agents/reasoning_agent.py b/src/backend/af/magentic_agents/reasoning_agent.py
new file mode 100644
index 000000000..b29b002ff
--- /dev/null
+++ b/src/backend/af/magentic_agents/reasoning_agent.py
@@ -0,0 +1,111 @@
+import logging
+
+from common.config.app_config import config
+from semantic_kernel import Kernel
+from semantic_kernel.agents import ChatCompletionAgent  # pylint: disable=E0611
+from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
+from v3.magentic_agents.common.lifecycle import MCPEnabledBase
+from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig
+from v3.magentic_agents.reasoning_search import ReasoningSearch
+from v3.config.agent_registry import agent_registry
+
+
+class ReasoningAgentTemplate(MCPEnabledBase):
+    """
+    SK ChatCompletionAgent with optional MCP plugin injected as a Kernel plugin.
+    No Azure AI Agents client is needed here. We only need a token provider for SK.
+    """
+
+    def __init__(
+        self,
+        agent_name: str,
+        agent_description: str,
+        agent_instructions: str,
+        model_deployment_name: str,
+        azure_openai_endpoint: str,
+        search_config: SearchConfig | None = None,
+        mcp_config: MCPConfig | None = None,
+    ) -> None:
+        super().__init__(mcp=mcp_config)
+        self.agent_name = agent_name
+        self.agent_description = agent_description
+        self.agent_instructions = agent_instructions
+        self._model_deployment_name = model_deployment_name
+        self._openai_endpoint = azure_openai_endpoint
+        self.search_config = search_config
+        self.reasoning_search: ReasoningSearch | None = None
+        self.logger = logging.getLogger(__name__)
+
+    def ad_token_provider(self) -> str:
+        credential = config.get_azure_credentials()
+        token = credential.get_token(config.AZURE_COGNITIVE_SERVICES)
+        return token.token
+
+    async def _after_open(self) -> None:
+        self.kernel = Kernel()
+
+        # Add Azure OpenAI Chat Completion service
+        chat = AzureChatCompletion(
+            deployment_name=self._model_deployment_name,
+            endpoint=self._openai_endpoint,
+            ad_token_provider=self.ad_token_provider,
+        )
+        self.kernel.add_service(chat)
+
+        # Initialize search capabilities
+        if self.search_config:
+            self.reasoning_search = ReasoningSearch(self.search_config)
+            await self.reasoning_search.initialize(self.kernel)
+
+        # Inject MCP plugin into the SK kernel if available
+        if self.mcp_plugin:
+            try:
+                self.kernel.add_plugin(self.mcp_plugin, plugin_name="mcp_tools")
+                self.logger.info("Added MCP plugin")
+            except Exception as ex:
+                self.logger.exception(f"Could not add MCP plugin to kernel: {ex}")
+
+        self._agent = ChatCompletionAgent(
+            kernel=self.kernel,
+            name=self.agent_name,
+            description=self.agent_description,
+            instructions=self.agent_instructions,
+        )
+
+        # Register agent with global registry for tracking and cleanup
+        try:
+            agent_registry.register_agent(self)
+            self.logger.info(f"📝 Registered agent '{self.agent_name}' with global registry")
+        except Exception as registry_error:
+            self.logger.warning(f"⚠️ Failed to register agent '{self.agent_name}' with registry: {registry_error}")
+
+    async def invoke(self, message: str):
+        """Invoke the agent with a message."""
+        if not self._agent:
+            raise RuntimeError("Agent not initialized. 
Call open() first.") + + async for response in self._agent.invoke(message): + yield response + + +# Backward‑compatible factory +async def create_reasoning_agent( + agent_name: str, + agent_description: str, + agent_instructions: str, + model_deployment_name: str, + azure_openai_endpoint: str, + search_config: SearchConfig | None = None, + mcp_config: MCPConfig | None = None, +) -> ReasoningAgentTemplate: + agent = ReasoningAgentTemplate( + agent_name=agent_name, + agent_description=agent_description, + agent_instructions=agent_instructions, + model_deployment_name=model_deployment_name, + azure_openai_endpoint=azure_openai_endpoint, + search_config=search_config, + mcp_config=mcp_config, + ) + await agent.open() + return agent diff --git a/src/backend/af/magentic_agents/reasoning_search.py b/src/backend/af/magentic_agents/reasoning_search.py new file mode 100644 index 000000000..7f944e7f5 --- /dev/null +++ b/src/backend/af/magentic_agents/reasoning_search.py @@ -0,0 +1,93 @@ +""" +RAG search capabilities for ReasoningAgentTemplate using AzureAISearchCollection. +Based on Semantic Kernel text search patterns. +""" + +from azure.core.credentials import AzureKeyCredential +from azure.search.documents import SearchClient +from semantic_kernel import Kernel +from semantic_kernel.functions import kernel_function +from v3.magentic_agents.models.agent_models import SearchConfig + + +class ReasoningSearch: + """Handles Azure AI Search integration for reasoning agents.""" + + def __init__(self, search_config: SearchConfig | None = None): + self.search_config = search_config + self.search_client: SearchClient | None = None + + async def initialize(self, kernel: Kernel) -> bool: + """Initialize the search collection with embeddings and add it to the kernel.""" + if ( + not self.search_config + or not self.search_config.endpoint + or not self.search_config.index_name + ): + print("Search configuration not available") + return False + + try: + + self.search_client = SearchClient( + endpoint=self.search_config.endpoint, + credential=AzureKeyCredential(self.search_config.api_key), + index_name=self.search_config.index_name, + ) + + # Add this class as a plugin so the agent can call search_documents + kernel.add_plugin(self, plugin_name="knowledge_search") + + print( + f"Added Azure AI Search plugin for index: {self.search_config.index_name}" + ) + return True + + except Exception as ex: + print(f"Could not initialize Azure AI Search: {ex}") + return False + + @kernel_function( + name="search_documents", + description="Search the knowledge base for relevant documents and information. Use this when you need to find specific information from internal documents or data.", + ) + async def search_documents(self, query: str, limit: str = "3") -> str: + """Search function that the agent can invoke to find relevant documents.""" + if not self.search_client: + return "Search service is not available." 
+
+        try:
+            limit_int = int(limit)
+            search_results = []
+
+            results = self.search_client.search(
+                search_text=query,
+                query_type="simple",
+                select=["content"],
+                top=limit_int,
+            )
+
+            for result in results:
+                search_results.append(f"content: {result['content']}")
+
+            if not search_results:
+                return f"No relevant documents found for query: '{query}'"
+
+            return "\n".join(search_results)
+
+        except Exception as ex:
+            return f"Search failed: {str(ex)}"
+
+    def is_available(self) -> bool:
+        """Check if search functionality is available."""
+        return self.search_client is not None
+
+
+# Simple factory function
+async def create_reasoning_search(
+    kernel: Kernel, search_config: SearchConfig | None
+) -> ReasoningSearch:
+    """Create and initialize a ReasoningSearch instance."""
+    search = ReasoningSearch(search_config)
+    await search.initialize(kernel)
+    return search
diff --git a/src/backend/af/models/messages.py b/src/backend/af/models/messages.py
new file mode 100644
index 000000000..2d92f4a68
--- /dev/null
+++ b/src/backend/af/models/messages.py
@@ -0,0 +1,206 @@
+"""Messages from the backend to the frontend via WebSocket (agent_framework variant)."""
+
+import time
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel
+
+# Use the agent-framework friendly models (previously from messages_kernel)
+from common.models.messages_af import AgentMessageType
+from af.models.models import MPlan, PlanStatus
+
+
+# ---------------------------------------------------------------------------
+# Dataclass message payloads
+# ---------------------------------------------------------------------------
+
+@dataclass(slots=True)
+class AgentMessage:
+    """Message from the backend to the frontend via WebSocket."""
+    agent_name: str
+    timestamp: str
+    content: str
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+@dataclass(slots=True)
+class AgentStreamStart:
+    """Start of a streaming message."""
+    agent_name: str
+
+
+@dataclass(slots=True)
+class AgentStreamEnd:
+    """End of a streaming message."""
+    agent_name: str
+
+
+@dataclass(slots=True)
+class AgentMessageStreaming:
+    """Streaming chunk from an agent."""
+    agent_name: str
+    content: str
+    is_final: bool = False
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+@dataclass(slots=True)
+class AgentToolMessage:
+    """Message representing that an agent produced one or more tool calls."""
+    agent_name: str
+    tool_calls: List["AgentToolCall"] = field(default_factory=list)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+@dataclass(slots=True)
+class AgentToolCall:
+    """A single tool invocation."""
+    tool_name: str
+    arguments: Dict[str, Any]
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+@dataclass(slots=True)
+class PlanApprovalRequest:
+    """Request for plan approval from the frontend."""
+    plan: MPlan
+    status: PlanStatus
+    context: dict | None = None
+
+
+@dataclass(slots=True)
+class PlanApprovalResponse:
+    """Response for plan approval from the frontend."""
+    m_plan_id: str
+    approved: bool
+    feedback: str | None = None
+    plan_id: str | None = None
+
+
+@dataclass(slots=True)
+class ReplanApprovalRequest:
+    """Request for replan approval from the frontend."""
+    new_plan: MPlan
+    reason: str
+    context: dict | None = None
+
+
+@dataclass(slots=True)
+class ReplanApprovalResponse:
+    """Response for replan approval from the frontend."""
+    plan_id: str
+    approved: bool
+    feedback: str | None = 
None + + +@dataclass(slots=True) +class UserClarificationRequest: + """Request for user clarification from the frontend.""" + question: str + request_id: str + + +@dataclass(slots=True) +class UserClarificationResponse: + """Response for user clarification from the frontend.""" + request_id: str + answer: str = "" + plan_id: str = "" + m_plan_id: str = "" + + +@dataclass(slots=True) +class FinalResultMessage: + """Final result message from the backend to the frontend.""" + content: str + status: str = "completed" + timestamp: Optional[float] = None + summary: str | None = None + + def to_dict(self) -> Dict[str, Any]: + data = { + "content": self.content, + "status": self.status, + "timestamp": self.timestamp or time.time(), + } + if self.summary: + data["summary"] = self.summary + return data + + +# --------------------------------------------------------------------------- +# Pydantic model replacing the previous KernelBaseModel +# --------------------------------------------------------------------------- + +class ApprovalRequest(BaseModel): + """Message sent to HumanAgent to request approval for a step.""" + step_id: str + plan_id: str + session_id: str + user_id: str + action: str + agent_name: str + + def to_dict(self) -> Dict[str, Any]: + # Consistent with dataclass pattern + return self.model_dump() + + +@dataclass(slots=True) +class AgentMessageResponse: + """Response message representing an agent's message (stream or final).""" + plan_id: str + agent: str + content: str + agent_type: AgentMessageType + is_final: bool = False + raw_data: str | None = None + streaming_message: str | None = None + + +@dataclass(slots=True) +class TimeoutNotification: + """Notification about a timeout (approval or clarification).""" + timeout_type: str # "approval" or "clarification" + request_id: str # plan_id or request_id + message: str # description + timestamp: float # epoch time + timeout_duration: float # seconds waited + + def to_dict(self) -> Dict[str, Any]: + return { + "timeout_type": self.timeout_type, + "request_id": self.request_id, + "message": self.message, + "timestamp": self.timestamp, + "timeout_duration": self.timeout_duration + } + + +class WebsocketMessageType(str, Enum): + """Types of WebSocket messages.""" + SYSTEM_MESSAGE = "system_message" + AGENT_MESSAGE = "agent_message" + AGENT_STREAM_START = "agent_stream_start" + AGENT_STREAM_END = "agent_stream_end" + AGENT_MESSAGE_STREAMING = "agent_message_streaming" + AGENT_TOOL_MESSAGE = "agent_tool_message" + PLAN_APPROVAL_REQUEST = "plan_approval_request" + PLAN_APPROVAL_RESPONSE = "plan_approval_response" + REPLAN_APPROVAL_REQUEST = "replan_approval_request" + REPLAN_APPROVAL_RESPONSE = "replan_approval_response" + USER_CLARIFICATION_REQUEST = "user_clarification_request" + USER_CLARIFICATION_RESPONSE = "user_clarification_response" + FINAL_RESULT_MESSAGE = "final_result_message" + TIMEOUT_NOTIFICATION = "timeout_notification" + diff --git a/src/backend/af/models/models.py b/src/backend/af/models/models.py new file mode 100644 index 000000000..adcf1fe88 --- /dev/null +++ b/src/backend/af/models/models.py @@ -0,0 +1,35 @@ +import uuid +from enum import Enum +from typing import List + +from pydantic import BaseModel, Field + + +class PlanStatus(str, Enum): + CREATED = "created" + QUEUED = "queued" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class MStep(BaseModel): + """model of a step in a plan""" + + agent: str = "" + action: str = "" + + +class MPlan(BaseModel): + """model of 
a plan""" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + user_id: str = "" + team_id: str = "" + plan_id: str = "" + overall_status: PlanStatus = PlanStatus.CREATED + user_request: str = "" + team: List[str] = [] + facts: str = "" + steps: List[MStep] = [] diff --git a/src/backend/af/models/orchestration_models.py b/src/backend/af/models/orchestration_models.py new file mode 100644 index 000000000..4eb2846e6 --- /dev/null +++ b/src/backend/af/models/orchestration_models.py @@ -0,0 +1,56 @@ +""" +Agent Framework version of orchestration models. + +Removes dependency on semantic_kernel.kernel_pydantic.KernelBaseModel and +uses standard Pydantic BaseModel + a lightweight dataclass for simple value objects. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional + +from pydantic import BaseModel + + +# --------------------------------------------------------------------------- +# Core lightweight value object +# --------------------------------------------------------------------------- + +@dataclass(slots=True) +class AgentDefinition: + """Simple agent descriptor used in planning output.""" + name: str + description: str + + def __repr__(self) -> str: # Keep original style + return f"Agent(name={self.name!r}, description={self.description!r})" + + +# --------------------------------------------------------------------------- +# Planner response models +# --------------------------------------------------------------------------- + +class PlannerResponseStep(BaseModel): + """One planned step referencing an agent and an action to perform.""" + agent: AgentDefinition + action: str + + +class PlannerResponsePlan(BaseModel): + """ + Full planner output including: + - original request + - selected team (list of AgentDefinition) + - extracted facts + - ordered steps + - summarization + - optional human clarification request + """ + request: str + team: List[AgentDefinition] + facts: str + steps: List[PlannerResponseStep] + summary_plan_and_steps: str + human_clarification_request: Optional[str] = None + diff --git a/src/backend/af/orchestration/__init__.py b/src/backend/af/orchestration/__init__.py new file mode 100644 index 000000000..47a4396bc --- /dev/null +++ b/src/backend/af/orchestration/__init__.py @@ -0,0 +1 @@ +# Orchestration package for Magentic orchestration management diff --git a/src/backend/af/orchestration/helper/plan_to_mplan_converter.py b/src/backend/af/orchestration/helper/plan_to_mplan_converter.py new file mode 100644 index 000000000..bc1dd5346 --- /dev/null +++ b/src/backend/af/orchestration/helper/plan_to_mplan_converter.py @@ -0,0 +1,194 @@ +import logging +import re +from typing import Iterable, List, Optional + +from v3.models.models import MPlan, MStep + +logger = logging.getLogger(__name__) + + +class PlanToMPlanConverter: + """ + Convert a free-form, bullet-style plan string into an MPlan object. + + Bullet parsing rules: + 1. Recognizes lines starting (optionally with indentation) followed by -, *, or β€’ + 2. Attempts to resolve the agent in priority order: + a. First bolded token (**AgentName**) if within detection window and in team + b. Any team agent name appearing (case-insensitive) within the first detection window chars + c. Fallback agent name (default 'MagenticAgent') + 3. Removes the matched agent token from the action text + 4. Ignores bullet lines whose remaining action is blank + + Notes: + - This does not mutate MPlan.user_id (caller can assign after parsing). 
+      - You can supply task text (becomes user_request) and facts text.
+      - Optionally detect sub-bullets (indent > 0). If enabled, a `level` integer is
+        returned alongside each MStep in an auxiliary `step_levels` list (since the
+        current MStep model doesn't have a level field).
+
+    Example:
+        converter = PlanToMPlanConverter(
+            team=["ResearchAgent", "AnalysisAgent"], task="Analyze Q4", facts="Some facts"
+        )
+        mplan = converter.parse(plan_text=raw)
+
+    """
+
+    BULLET_RE = re.compile(r"^(?P<indent>\s*)[-•*]\s+(?P<body>.+)$")
+    BOLD_AGENT_RE = re.compile(r"\*\*([A-Za-z0-9_]+)\*\*")
+    STRIP_BULLET_MARKER_RE = re.compile(r"^[-•*]\s+")
+
+    def __init__(
+        self,
+        team: Iterable[str],
+        task: str = "",
+        facts: str = "",
+        detection_window: int = 25,
+        fallback_agent: str = "MagenticAgent",
+        enable_sub_bullets: bool = False,
+        trim_actions: bool = True,
+        collapse_internal_whitespace: bool = True,
+    ):
+        self.team: List[str] = list(team)
+        self.task = task
+        self.facts = facts
+        self.detection_window = detection_window
+        self.fallback_agent = fallback_agent
+        self.enable_sub_bullets = enable_sub_bullets
+        self.trim_actions = trim_actions
+        self.collapse_internal_whitespace = collapse_internal_whitespace
+
+        # Map for faster case-insensitive lookups while preserving canonical form
+        self._team_lookup = {t.lower(): t for t in self.team}
+
+    # ---------------- Public API ---------------- #
+
+    def parse(self, plan_text: str) -> MPlan:
+        """
+        Parse the supplied bullet-style plan text into an MPlan.
+
+        Returns:
+            MPlan with team, user_request, facts, steps populated.
+
+        Side channel (if sub-bullets enabled):
+            self.last_step_levels: List[int] parallel to steps (0 = top, 1 = sub, etc.)
+        """
+        mplan = MPlan()
+        mplan.team = self.team.copy()
+        mplan.user_request = self.task or mplan.user_request
+        mplan.facts = self.facts or mplan.facts
+
+        lines = self._preprocess_lines(plan_text)
+
+        step_levels: List[int] = []
+        for raw_line in lines:
+            bullet_match = self.BULLET_RE.match(raw_line)
+            if not bullet_match:
+                continue  # ignore non-bullet lines entirely
+
+            indent = bullet_match.group("indent") or ""
+            body = bullet_match.group("body").strip()
+
+            level = 0
+            if self.enable_sub_bullets and indent:
+                # Simple heuristic: any indentation => level 1 (could extend to deeper)
+                level = 1
+
+            agent, action = self._extract_agent_and_action(body)
+
+            if not action:
+                continue
+
+            mplan.steps.append(MStep(agent=agent, action=action))
+            if self.enable_sub_bullets:
+                step_levels.append(level)
+
+        if self.enable_sub_bullets:
+            # Expose levels so caller can correlate (parallel list)
+            self.last_step_levels = step_levels  # type: ignore[attr-defined]
+
+        return mplan
+
+    # ---------------- Internal Helpers ---------------- #
+
+    def _preprocess_lines(self, plan_text: str) -> List[str]:
+        lines = plan_text.splitlines()
+        cleaned: List[str] = []
+        for line in lines:
+            stripped = line.rstrip()
+            if stripped:
+                cleaned.append(stripped)
+        return cleaned
+
+    def _extract_agent_and_action(self, body: str) -> tuple[str, str]:
+        """
+        Apply bold-first strategy, then window scan fallback.
+        Returns (agent, action_text).
+        """
+        original = body
+
+        # 1. Try bold token
+        agent, body_after = self._try_bold_agent(original)
+        if agent:
+            action = self._finalize_action(body_after)
+            return agent, action
+
+        # 2. Try window scan
+        agent2, body_after2 = self._try_window_agent(original)
+        if agent2:
+            action = self._finalize_action(body_after2)
+            return agent2, action
+
+        # 3. Fallback
+        action = self._finalize_action(original)
+        return self.fallback_agent, action
+
+    def _try_bold_agent(self, text: str) -> tuple[Optional[str], str]:
+        m = self.BOLD_AGENT_RE.search(text)
+        if not m:
+            return None, text
+        if m.start() <= self.detection_window:
+            candidate = m.group(1)
+            canonical = self._team_lookup.get(candidate.lower())
+            if canonical:  # valid agent
+                cleaned = text[: m.start()] + text[m.end() :]
+                return canonical, cleaned.strip()
+        return None, text
+
+    def _try_window_agent(self, text: str) -> tuple[Optional[str], str]:
+        head_segment = text[: self.detection_window].lower()
+        for canonical in self.team:
+            if canonical.lower() in head_segment:
+                # Remove first occurrence (case-insensitive)
+                pattern = re.compile(re.escape(canonical), re.IGNORECASE)
+                cleaned = pattern.sub("", text, count=1)
+                cleaned = cleaned.replace("*", "")
+                return canonical, cleaned.strip()
+        return None, text
+
+    def _finalize_action(self, action: str) -> str:
+        if self.trim_actions:
+            action = action.strip()
+        if self.collapse_internal_whitespace:
+            action = re.sub(r"\s+", " ", action)
+        return action
+
+    # --------------- Convenience (static) --------------- #
+
+    @staticmethod
+    def convert(
+        plan_text: str,
+        team: Iterable[str],
+        task: str = "",
+        facts: str = "",
+        **kwargs,
+    ) -> MPlan:
+        """
+        One-shot convenience method:
+            mplan = PlanToMPlanConverter.convert(plan_text, team, task="X")
+        """
+        return PlanToMPlanConverter(
+            team=team,
+            task=task,
+            facts=facts,
+            **kwargs,
+        ).parse(plan_text)
diff --git a/src/backend/af/orchestration/human_approval_manager.py b/src/backend/af/orchestration/human_approval_manager.py
new file mode 100644
index 000000000..bfba4befe
--- /dev/null
+++ b/src/backend/af/orchestration/human_approval_manager.py
@@ -0,0 +1,370 @@
+"""
+Human-in-the-loop Magentic Manager for employee onboarding orchestration.
+Extends StandardMagenticManager to add approval gates before plan execution.
+"""
+
+import asyncio
+import logging
+from typing import Any, Optional
+
+import v3.models.messages as messages
+from semantic_kernel.agents.orchestration.magentic import (
+    MagenticContext,
+    ProgressLedger,
+    ProgressLedgerItem,
+    StandardMagenticManager,
+)
+from semantic_kernel.agents.orchestration.prompts._magentic_prompts import (
+    ORCHESTRATOR_FINAL_ANSWER_PROMPT,
+    ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
+    ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
+)
+from semantic_kernel.contents import ChatMessageContent
+from v3.config.settings import connection_config, orchestration_config
+from v3.models.models import MPlan
+from v3.orchestration.helper.plan_to_mplan_converter import \
+    PlanToMPlanConverter
+
+# Using a module level logger to avoid pydantic issues around inherited fields
+logger = logging.getLogger(__name__)
+
+
+# Note: create_progress_ledger below returns a "request satisfied" ledger once max rounds are exceeded
+class HumanApprovalMagenticManager(StandardMagenticManager):
+    """
+    Extended Magentic manager that requires human approval before executing plan steps.
+    Provides interactive approval for each step in the orchestration plan.
+    """
+
+    # Define Pydantic fields to avoid validation errors
+    approval_enabled: bool = True
+    magentic_plan: Optional[MPlan] = None
+    current_user_id: str
+
+    def __init__(self, user_id: str, *args, **kwargs):
+        """
+        Initialize the HumanApprovalMagenticManager.
+        Args:
+            user_id: ID of the user to associate with this orchestration instance.
+            *args: Additional positional arguments for the parent StandardMagenticManager.
+            **kwargs: Additional keyword arguments for the parent StandardMagenticManager.
+        """
+
+        # Extend the parent prompt templates with solution-specific guidance
+
+        plan_append = """
+IMPORTANT: Never ask the user for information or clarification until all agents on the team have been asked first.
+
+EXAMPLE: If the user request involves product information, first ask all agents on the team to provide the information.
+Do not ask the user unless all agents have been consulted and the information is still missing.
+
+Plan steps should always include a bullet point, followed by an agent name, followed by a description of the action
+to be taken. If a step involves multiple actions, separate them into distinct steps with an agent included in each step.
+If the step is taken by an agent that is not part of the team, such as the MagenticManager, please always list the MagenticManager as the agent for that step. At any time, if more information is needed from the user, use the ProxyAgent to request this information.
+
+Here is an example of a well-structured plan:
+- **EnhancedResearchAgent** to gather authoritative data on the latest industry trends and best practices in employee onboarding
+- **EnhancedResearchAgent** to gather authoritative data on innovative onboarding techniques that enhance new hire engagement and retention.
+- **DocumentCreationAgent** to draft a comprehensive onboarding plan that includes a detailed schedule of onboarding activities and milestones.
+- **DocumentCreationAgent** to draft a comprehensive onboarding plan that includes a checklist of resources and materials needed for effective onboarding.
+- **ProxyAgent** to review the drafted onboarding plan for clarity and completeness.
+- **MagenticManager** to finalize the onboarding plan and prepare it for presentation to stakeholders.
+
+"""
+
+        final_append = """
+    DO NOT EVER OFFER TO HELP FURTHER IN THE FINAL ANSWER! Just provide the final answer and end with a polite closing.
+"""
+
+        # (A task_ledger_facts_prompt override could be appended here in the same way.)
+        kwargs["task_ledger_plan_prompt"] = (
+            ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT + plan_append
+        )
+        kwargs["task_ledger_plan_update_prompt"] = (
+            ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append
+        )
+        kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append
+
+        kwargs["current_user_id"] = user_id
+
+        super().__init__(*args, **kwargs)
+
+    async def plan(self, magentic_context: MagenticContext) -> Any:
+        """
+        Override the plan method to create the plan first, then ask for approval before execution.
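+
+        Raises:
+            Exception: If the user rejects the plan.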
+ """ + # Extract task text from the context + task_text = magentic_context.task + if hasattr(task_text, "content"): + task_text = task_text.content + elif not isinstance(task_text, str): + task_text = str(task_text) + + logger.info("\n Human-in-the-Loop Magentic Manager Creating Plan:") + logger.info(" Task: %s", task_text) + logger.info("-" * 60) + + # First, let the parent create the actual plan + logger.info(" Creating execution plan...") + plan = await super().plan(magentic_context) + logger.info(" Plan created: %s", plan) + + self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger) + + self.magentic_plan.user_id = self.current_user_id + + # Request approval from the user before executing the plan + approval_message = messages.PlanApprovalRequest( + plan=self.magentic_plan, + status="PENDING_APPROVAL", + context=( + { + "task": task_text, + "participant_descriptions": magentic_context.participant_descriptions, + } + if hasattr(magentic_context, "participant_descriptions") + else {} + ), + ) + try: + orchestration_config.plans[self.magentic_plan.id] = self.magentic_plan + except Exception as e: + logger.error("Error processing plan approval: %s", e) + + # Send the approval request to the user's WebSocket + # The user_id will be automatically retrieved from context + await connection_config.send_status_update_async( + message=approval_message, + user_id=self.current_user_id, + message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST, + ) + + # Wait for user approval + approval_response = await self._wait_for_user_approval(approval_message.plan.id) + + if approval_response and approval_response.approved: + logger.info("Plan approved - proceeding with execution...") + return plan + else: + logger.debug("Plan execution cancelled by user") + await connection_config.send_status_update_async( + { + "type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, + "data": approval_response, + }, + user_id=self.current_user_id, + message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, + ) + raise Exception("Plan execution cancelled by user") + + async def replan(self, magentic_context: MagenticContext) -> Any: + """ + Override to add websocket messages for replanning events. 
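+        Currently this logs the outcome and delegates to the parent
+        implementation; websocket notifications can be hooked in here.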
+ """ + + logger.info("\nHuman-in-the-Loop Magentic Manager replanned:") + replan = await super().replan(magentic_context=magentic_context) + logger.info("Replanned: %s", replan) + return replan + + async def create_progress_ledger( + self, magentic_context: MagenticContext + ) -> ProgressLedger: + """Check for max rounds exceeded and send final message if so.""" + if magentic_context.round_count >= orchestration_config.max_rounds: + # Send final message to user + final_message = messages.FinalResultMessage( + content="Process terminated: Maximum rounds exceeded", + status="terminated", + summary=f"Stopped after {magentic_context.round_count} rounds (max: {orchestration_config.max_rounds})", + ) + + await connection_config.send_status_update_async( + message=final_message, + user_id=self.current_user_id, + message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE, + ) + + return ProgressLedger( + is_request_satisfied=ProgressLedgerItem( + reason="Maximum rounds exceeded", answer=True + ), + is_in_loop=ProgressLedgerItem(reason="Terminating", answer=False), + is_progress_being_made=ProgressLedgerItem( + reason="Terminating", answer=False + ), + next_speaker=ProgressLedgerItem(reason="Task complete", answer=""), + instruction_or_question=ProgressLedgerItem( + reason="Task complete", + answer="Process terminated due to maximum rounds exceeded", + ), + ) + + return await super().create_progress_ledger(magentic_context) + + # plan_id will not be optional in future + async def _wait_for_user_approval( + self, m_plan_id: Optional[str] = None + ) -> Optional[messages.PlanApprovalResponse]: + """ + Wait for user approval response using event-driven pattern with timeout handling. + + Args: + m_plan_id: The plan ID to wait for approval + + Returns: + PlanApprovalResponse: Approval result with approved status and plan ID + + Raises: + asyncio.TimeoutError: If timeout is exceeded (300 seconds default) + """ + logger.info(f"Waiting for user approval for plan: {m_plan_id}") + + if not m_plan_id: + logger.error("No plan ID provided for approval") + return messages.PlanApprovalResponse(approved=False, m_plan_id=m_plan_id) + + # Initialize approval as pending using the new event-driven method + orchestration_config.set_approval_pending(m_plan_id) + + try: + # Wait for approval with timeout using the new event-driven method + approved = await orchestration_config.wait_for_approval(m_plan_id) + + logger.info(f"Approval received for plan {m_plan_id}: {approved}") + return messages.PlanApprovalResponse( + approved=approved, m_plan_id=m_plan_id + ) + except asyncio.TimeoutError: + # Enhanced timeout handling - notify user via WebSocket and cleanup + logger.debug(f"Approval timeout for plan {m_plan_id} - notifying user and terminating process") + + # Create timeout notification message + timeout_message = messages.TimeoutNotification( + timeout_type="approval", + request_id=m_plan_id, + message=f"Plan approval request timed out after {orchestration_config.default_timeout} seconds. 
Please try again.", + timestamp=asyncio.get_event_loop().time(), + timeout_duration=orchestration_config.default_timeout + ) + + # Send timeout notification to user via WebSocket + try: + await connection_config.send_status_update_async( + message=timeout_message, + user_id=self.current_user_id, + message_type=messages.WebsocketMessageType.TIMEOUT_NOTIFICATION, + ) + logger.info(f"Timeout notification sent to user {self.current_user_id} for plan {m_plan_id}") + except Exception as e: + logger.error(f"Failed to send timeout notification: {e}") + + # Clean up this specific request + orchestration_config.cleanup_approval(m_plan_id) + + # Return None to indicate silent termination + # The timeout naturally stops this specific wait operation without affecting other tasks + return None + + except KeyError as e: + # Silent error handling for invalid plan IDs + logger.debug(f"Plan ID not found: {e} - terminating process silently") + return None + + except asyncio.CancelledError: + # Handle task cancellation gracefully + logger.debug(f"Approval request {m_plan_id} was cancelled") + orchestration_config.cleanup_approval(m_plan_id) + return None + + except Exception as e: + # Silent error handling for unexpected errors + logger.debug(f"Unexpected error waiting for approval: {e} - terminating process silently") + orchestration_config.cleanup_approval(m_plan_id) + return None + finally: + # Ensure cleanup happens for any incomplete requests + # This provides an additional safety net for resource cleanup + if (m_plan_id in orchestration_config.approvals and orchestration_config.approvals[m_plan_id] is None): + logger.debug(f"Final cleanup for pending approval plan {m_plan_id}") + orchestration_config.cleanup_approval(m_plan_id) + + async def prepare_final_answer( + self, magentic_context: MagenticContext + ) -> ChatMessageContent: + """ + Override to ensure final answer is prepared after all steps are executed. 
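+        Delegates to the parent implementation; the customized final-answer
+        prompt set in __init__ shapes the closing message.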
+ """ + logger.info("\n Magentic Manager - Preparing final answer...") + + return await super().prepare_final_answer(magentic_context) + + def plan_to_obj(self, magentic_context, ledger) -> MPlan: + """Convert the generated plan from the ledger into a structured MPlan object.""" + + return_plan: MPlan = PlanToMPlanConverter.convert( + plan_text=ledger.plan.content, + facts=ledger.facts.content, + team=list(magentic_context.participant_descriptions.keys()), + task=magentic_context.task, + ) + + # # get the request text from the ledger + # if hasattr(magentic_context, 'task'): + # return_plan.user_request = magentic_context.task + + # return_plan.team = list(magentic_context.participant_descriptions.keys()) + + # # Get the facts content from the ledger + # if hasattr(ledger, 'facts') and ledger.facts.content: + # return_plan.facts = ledger.facts.content + + # # Get the plan / steps content from the ledger + # # Split the description into lines and clean them + # lines = [line.strip() for line in ledger.plan.content.strip().split('\n') if line.strip()] + + # found_agent = None + # prefix = None + + # for line in lines: + # found_agent = None + # prefix = None + # # log the line for troubleshooting + # logger.debug("Processing plan line: %s", line) + + # # match only lines that have bullet points + # if re.match(r'^[-β€’*]\s+', line): + # # Remove the bullet point marker + # line = re.sub(r'^[-β€’*]\s+', '', line).strip() + + # # Look for agent names in the line + + # for agent_name in return_plan.team: + # # Check if agent name appears in the line (case insensitive) + # if agent_name.lower() in line[:20].lower(): + # found_agent = agent_name + # line = line.split(agent_name, 1) + # line = line[1].strip() if len(line) > 1 else "" + # line = line.replace('*', '').strip() + # break + + # if not found_agent: + # # If no agent found, assign to ProxyAgent if available + # found_agent = "MagenticAgent" + # # If line indicates a following list of actions (e.g. "Assign **EnhancedResearchAgent** + # # to gather authoritative data on:") save and prefix to the steps + # # if line.endswith(':'): + # # line = line.replace(':', '').strip() + # # prefix = line + " " + + # # Don't create a step if action is blank + # if line.strip() != "": + # if prefix: + # line = prefix + line + # # Create the step object + # step = MStep(agent=found_agent, action=line) + + # # add the step to the plan + # return_plan.steps.append(step) # pylint: disable=E1101 + + return return_plan diff --git a/src/backend/af/orchestration/orchestration_manager.py b/src/backend/af/orchestration/orchestration_manager.py new file mode 100644 index 000000000..7db458fee --- /dev/null +++ b/src/backend/af/orchestration/orchestration_manager.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft. All rights reserved. 
+"""Orchestration manager to handle the orchestration logic.""" +import asyncio +import logging +import uuid +from typing import List, Optional + +from common.config.app_config import config +from common.models.messages_kernel import TeamConfiguration +from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration +from semantic_kernel.agents.runtime import InProcessRuntime + +# Create custom execution settings to fix schema issues +from semantic_kernel.connectors.ai.open_ai import ( + AzureChatCompletion, OpenAIChatPromptExecutionSettings) +from semantic_kernel.contents import (ChatMessageContent, + StreamingChatMessageContent) +from v3.callbacks.response_handlers import (agent_response_callback, + streaming_agent_response_callback) +from v3.config.settings import connection_config, orchestration_config +from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory +from v3.models.messages import WebsocketMessageType +from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager + + +class OrchestrationManager: + """Manager for handling orchestration logic.""" + + # Class-scoped logger (always available to classmethods) + logger = logging.getLogger(f"{__name__}.OrchestrationManager") + + def __init__(self): + self.user_id: Optional[str] = None + # Optional alias (helps with autocomplete) + self.logger = self.__class__.logger + + @classmethod + async def init_orchestration( + cls, agents: List, user_id: str = None + ) -> MagenticOrchestration: + """Main function to run the agents.""" + + # Custom execution settings that should work with Azure OpenAI + execution_settings = OpenAIChatPromptExecutionSettings( + max_tokens=4000, temperature=0.1 + ) + + credential = config.get_azure_credential(client_id=config.AZURE_CLIENT_ID) + + def get_token(): + token = credential.get_token("https://cognitiveservices.azure.com/.default") + return token.token + + # 1. 
Create a Magentic orchestration with Azure OpenAI + magentic_orchestration = MagenticOrchestration( + members=agents, + manager=HumanApprovalMagenticManager( + user_id=user_id, + chat_completion_service=AzureChatCompletion( + deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME, + endpoint=config.AZURE_OPENAI_ENDPOINT, + ad_token_provider=get_token, # Use token provider function + ), + execution_settings=execution_settings, + ), + agent_response_callback=cls._user_aware_agent_callback(user_id), + streaming_agent_response_callback=cls._user_aware_streaming_callback( + user_id + ), + ) + return magentic_orchestration + + @staticmethod + def _user_aware_agent_callback(user_id: str): + """Factory method that creates a callback with captured user_id""" + + def callback(message: ChatMessageContent): + return agent_response_callback(message, user_id) + + return callback + + @staticmethod + def _user_aware_streaming_callback(user_id: str): + """Factory method that creates a streaming callback with captured user_id""" + + async def callback( + streaming_message: StreamingChatMessageContent, is_final: bool + ): + return await streaming_agent_response_callback( + streaming_message, is_final, user_id + ) + + return callback + + @classmethod + async def get_current_or_new_orchestration( + cls, user_id: str, team_config: TeamConfiguration, team_switched: bool + ) -> MagenticOrchestration: # add team_switched: bool parameter + """get existing orchestration instance.""" + current_orchestration = orchestration_config.get_current_orchestration(user_id) + if ( + current_orchestration is None or team_switched + ): # add check for team_switched flag + if current_orchestration is not None and team_switched: + for agent in current_orchestration._members: + if agent.name != "ProxyAgent": + try: + await agent.close() + except Exception as e: + cls.logger.error("Error closing agent: %s", e) + factory = MagenticAgentFactory() + agents = await factory.get_agents(user_id=user_id, team_config_input=team_config) + orchestration_config.orchestrations[user_id] = await cls.init_orchestration( + agents, user_id + ) + return orchestration_config.get_current_orchestration(user_id) + + async def run_orchestration(self, user_id, input_task) -> None: + """Run the orchestration with user input loop.""" + + job_id = str(uuid.uuid4()) + + # Use the new event-driven method to set approval as pending + orchestration_config.set_approval_pending(job_id) + + magentic_orchestration = orchestration_config.get_current_orchestration(user_id) + + if magentic_orchestration is None: + raise ValueError("Orchestration not initialized for user.") + + try: + if hasattr(magentic_orchestration, "_manager") and hasattr( + magentic_orchestration._manager, "current_user_id" + ): + magentic_orchestration._manager.current_user_id = user_id + self.logger.debug(f"DEBUG: Set user_id on manager = {user_id}") + except Exception as e: + self.logger.error(f"Error setting user_id on manager: {e}") + + runtime = InProcessRuntime() + runtime.start() + + try: + + orchestration_result = await magentic_orchestration.invoke( + task=input_task.description, + runtime=runtime, + ) + + try: + self.logger.info("\nAgent responses:") + value = await orchestration_result.get() + self.logger.info(f"\nFinal result:\n{value}") + self.logger.info("=" * 50) + + # Send final result via WebSocket + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.FINAL_RESULT_MESSAGE, + "data": { + "content": str(value), + "status": "completed", + "timestamp": 
asyncio.get_event_loop().time(), + }, + }, + user_id, + message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE, + ) + self.logger.info(f"Final result sent via WebSocket to user {user_id}") + except Exception as e: + self.logger.info(f"Error: {e}") + self.logger.info(f"Error type: {type(e).__name__}") + if hasattr(e, "__dict__"): + self.logger.info(f"Error attributes: {e.__dict__}") + self.logger.info("=" * 50) + + except Exception as e: + self.logger.error(f"Unexpected error: {e}") + finally: + await runtime.stop_when_idle() From 982dcb4b4fcaf985e0d37cfcc3ae7fa55a5cac50 Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Tue, 21 Oct 2025 10:31:01 -0400 Subject: [PATCH 6/9] Add agent lifecycle, models, and refactor agents Introduces common lifecycle management for agents, new agent configuration models, and refactors ProxyAgent and ReasoningAgentTemplate to use agent_framework primitives. Removes Semantic Kernel dependencies, adds Azure AI Search integration, and streamlines agent creation and invocation logic for improved maintainability and extensibility. --- .../af/magentic_agents/common/lifecycle.py | 188 +++++++ .../af/magentic_agents/models/agent_models.py | 84 +++ src/backend/af/magentic_agents/proxy_agent.py | 502 +++++++----------- .../af/magentic_agents/reasoning_agent.py | 276 +++++++--- .../af/magentic_agents/reasoning_search.py | 200 +++++-- 5 files changed, 822 insertions(+), 428 deletions(-) create mode 100644 src/backend/af/magentic_agents/common/lifecycle.py create mode 100644 src/backend/af/magentic_agents/models/agent_models.py diff --git a/src/backend/af/magentic_agents/common/lifecycle.py b/src/backend/af/magentic_agents/common/lifecycle.py new file mode 100644 index 000000000..17d14676e --- /dev/null +++ b/src/backend/af/magentic_agents/common/lifecycle.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import os +from contextlib import AsyncExitStack +from typing import Any, Optional + +from azure.ai.projects.aio import AIProjectClient +from azure.identity.aio import DefaultAzureCredential + +from agent_framework.azure import AzureAIAgentClient +from agent_framework import HostedMCPTool + +from af.magentic_agents.models.agent_models import MCPConfig +from af.config.agent_registry import agent_registry + + +class MCPEnabledBase: + """ + Base that owns an AsyncExitStack and (optionally) prepares an MCP tool + for subclasses to attach to ChatOptions (agent_framework style). + Subclasses must implement _after_open() and assign self._agent. 
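+
+    Typical lifecycle (a sketch; ``SomeAgent`` stands in for a concrete subclass):
+
+        async with SomeAgent(mcp=MCPConfig.from_env()) as agent:
+            ...  # attribute access is delegated to the wrapped client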
+    """
+
+    def __init__(self, mcp: MCPConfig | None = None) -> None:
+        self._stack: AsyncExitStack | None = None
+        self.mcp_cfg: MCPConfig | None = mcp
+        self.mcp_tool: HostedMCPTool | None = None
+        self._agent: Any | None = None  # delegate target (e.g., AzureAIAgentClient)
+
+    async def open(self) -> "MCPEnabledBase":
+        if self._stack is not None:
+            return self
+        self._stack = AsyncExitStack()
+        self._prepare_mcp_tool()
+        await self._after_open()
+        return self
+
+    async def close(self) -> None:
+        if self._stack is None:
+            return
+        try:
+            # Attempt to close the underlying agent/client if it exposes close()
+            if self._agent and hasattr(self._agent, "close"):
+                try:
+                    await self._agent.close()  # AzureAIAgentClient has async close
+                except Exception:  # noqa: BLE001
+                    pass
+            # Unregister from registry if present
+            try:
+                agent_registry.unregister_agent(self)
+            except Exception:  # noqa: BLE001
+                pass
+            await self._stack.aclose()
+        finally:
+            self._stack = None
+            self.mcp_tool = None
+            self._agent = None
+
+    # Context manager
+    async def __aenter__(self) -> "MCPEnabledBase":
+        return await self.open()
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:  # noqa: D401
+        await self.close()
+
+    # Delegate to underlying agent
+    def __getattr__(self, name: str) -> Any:
+        if self._agent is not None:
+            return getattr(self._agent, name)
+        raise AttributeError(f"{type(self).__name__} has no attribute '{name}'")
+
+    async def _after_open(self) -> None:
+        """Subclasses must build self._agent here."""
+        raise NotImplementedError
+
+    def _prepare_mcp_tool(self) -> None:
+        """Translate MCPConfig to a HostedMCPTool (agent_framework construct)."""
+        if not self.mcp_cfg:
+            return
+        try:
+            self.mcp_tool = HostedMCPTool(
+                name=self.mcp_cfg.name,
+                description=self.mcp_cfg.description,
+                server_label=self.mcp_cfg.name.replace(" ", "_"),
+                url=self.mcp_cfg.url,  # MCP server endpoint from configuration
+            )
+        except Exception:  # noqa: BLE001
+            self.mcp_tool = None
+
+
+class AzureAgentBase(MCPEnabledBase):
+    """
+    Extends MCPEnabledBase with Azure credential + AIProjectClient contexts.
+    Subclasses:
+      - create or attach an Azure AI Agent definition
+      - instantiate an AzureAIAgentClient and assign to self._agent
+      - optionally register themselves via agent_registry
+    """
+
+    def __init__(self, mcp: MCPConfig | None = None) -> None:
+        super().__init__(mcp=mcp)
+        self.creds: Optional[DefaultAzureCredential] = None
+        self.client: Optional[AIProjectClient] = None
+        self.project_endpoint: Optional[str] = None
+        self._created_ephemeral: bool = False  # reserved if you add ephemeral agent cleanup
+
+    async def open(self) -> "AzureAgentBase":
+        if self._stack is not None:
+            return self
+        self._stack = AsyncExitStack()
+
+        # Resolve Azure AI Project endpoint (mirrors old SK env var usage)
+        self.project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
+        if not self.project_endpoint:
+            raise RuntimeError(
+                "AZURE_AI_PROJECT_ENDPOINT environment variable is required for AzureAgentBase."
+ ) + + # Acquire credential + self.creds = DefaultAzureCredential() + await self._stack.enter_async_context(self.creds) + + # Create AIProjectClient + self.client = AIProjectClient( + endpoint=self.project_endpoint, + credential=self.creds, + ) + await self._stack.enter_async_context(self.client) + + # Prepare MCP + self._prepare_mcp_tool() + + # Let subclass build agent client + await self._after_open() + + # Register agent (best effort) + try: + agent_registry.register_agent(self) + except Exception: # noqa: BLE001 + pass + + return self + + async def close(self) -> None: + """ + Close agent client and Azure resources. + If you implement ephemeral agent creation in subclasses, you can + optionally delete the agent definition here. + """ + try: + # Example optional clean up of an agent id: + # if self._agent and isinstance(self._agent, AzureAIAgentClient) and self._agent._should_delete_agent: + # try: + # if self.client and self._agent.agent_id: + # await self.client.agents.delete_agent(self._agent.agent_id) + # except Exception: + # pass + + # Close underlying client via base close + if self._agent and hasattr(self._agent, "close"): + try: + await self._agent.close() + except Exception: # noqa: BLE001 + pass + + # Unregister from registry + try: + agent_registry.unregister_agent(self) + except Exception: # noqa: BLE001 + pass + + # Close credential and project client + if self.client: + try: + await self.client.close() + except Exception: # noqa: BLE001 + pass + if self.creds: + try: + await self.creds.close() + except Exception: # noqa: BLE001 + pass + + finally: + await super().close() + self.client = None + self.creds = None + self.project_endpoint = None diff --git a/src/backend/af/magentic_agents/models/agent_models.py b/src/backend/af/magentic_agents/models/agent_models.py new file mode 100644 index 000000000..40f19161d --- /dev/null +++ b/src/backend/af/magentic_agents/models/agent_models.py @@ -0,0 +1,84 @@ +"""Models for agent configurations.""" + +from dataclasses import dataclass + +from common.config.app_config import config + + +@dataclass(slots=True) +class MCPConfig: + """Configuration for connecting to an MCP server.""" + + url: str = "" + name: str = "MCP" + description: str = "" + tenant_id: str = "" + client_id: str = "" + + @classmethod + def from_env(cls) -> "MCPConfig": + url = config.MCP_SERVER_ENDPOINT + name = config.MCP_SERVER_NAME + description = config.MCP_SERVER_DESCRIPTION + tenant_id = config.AZURE_TENANT_ID + client_id = config.AZURE_CLIENT_ID + + # Raise exception if any required environment variable is missing + if not all([url, name, description, tenant_id, client_id]): + raise ValueError(f"{cls.__name__} Missing required environment variables") + + return cls( + url=url, + name=name, + description=description, + tenant_id=tenant_id, + client_id=client_id, + ) + + +# @dataclass(slots=True) +# class BingConfig: +# """Configuration for connecting to Bing Search.""" +# connection_name: str = "Bing" + +# @classmethod +# def from_env(cls) -> "BingConfig": +# connection_name = config.BING_CONNECTION_NAME + +# # Raise exception if required environment variable is missing +# if not connection_name: +# raise ValueError(f"{cls.__name__} Missing required environment variables") + +# return cls( +# connection_name=connection_name, +# ) + + +@dataclass(slots=True) +class SearchConfig: + """Configuration for connecting to Azure AI Search.""" + + connection_name: str | None = None + endpoint: str | None = None + index_name: str | None = None + api_key: str | None = 
None # API key for Azure AI Search + + @classmethod + def from_env(cls) -> "SearchConfig": + connection_name = config.AZURE_AI_SEARCH_CONNECTION_NAME + index_name = config.AZURE_AI_SEARCH_INDEX_NAME + endpoint = config.AZURE_AI_SEARCH_ENDPOINT + api_key = config.AZURE_AI_SEARCH_API_KEY + + # Raise exception if any required environment variable is missing + if not all([connection_name, index_name, endpoint]): + raise ValueError( + f"{cls.__name__} Missing required Azure Search environment variables" + ) + + return cls( + connection_name=connection_name, + index_name=index_name, + endpoint=endpoint, + api_key=api_key, + ) diff --git a/src/backend/af/magentic_agents/proxy_agent.py b/src/backend/af/magentic_agents/proxy_agent.py index 02cd90b79..c7fd4a38c 100644 --- a/src/backend/af/magentic_agents/proxy_agent.py +++ b/src/backend/af/magentic_agents/proxy_agent.py @@ -1,373 +1,253 @@ -# Copyright (c) Microsoft. All rights reserved. -"""Proxy agent that prompts for human clarification.""" +""" +ProxyAgentAF: Human clarification proxy implemented on agent_framework primitives. + +Responsibilities: +- Request clarification from a human via websocket +- Await response (with timeout + cancellation handling via orchestration_config) +- Yield ChatResponseUpdate objects compatible with agent_framework streaming loops +""" + +from __future__ import annotations import asyncio import logging import time import uuid -from collections.abc import AsyncIterable -from typing import AsyncIterator, Optional - -from pydantic import Field -from semantic_kernel.agents import ( # pylint: disable=no-name-in-module - AgentResponseItem, - AgentThread, +from dataclasses import dataclass, field +from typing import AsyncIterator, List, Optional + +from agent_framework import ( + ChatResponseUpdate, + Role, + TextContent, + UsageContent, + UsageDetails, ) -from semantic_kernel.agents.agent import Agent -from semantic_kernel.contents import ( - AuthorRole, - ChatMessageContent, - StreamingChatMessageContent, +from af.config.settings import connection_config, orchestration_config +from af.models.messages import ( + UserClarificationRequest, + UserClarificationResponse, + TimeoutNotification, + WebsocketMessageType, ) -from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.history_reducer.chat_history_reducer import ( - ChatHistoryReducer, -) -from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException -from typing_extensions import override -from v3.callbacks.response_handlers import (agent_response_callback, - streaming_agent_response_callback) -from v3.config.settings import connection_config, orchestration_config -from v3.models.messages import (UserClarificationRequest, - UserClarificationResponse, WebsocketMessageType) -# Initialize logger for the module logger = logging.getLogger(__name__) -class DummyAgentThread(AgentThread): - """Dummy thread implementation for proxy agent.""" - - def __init__( - self, chat_history: ChatHistory | None = None, thread_id: str | None = None - ): - super().__init__() - self._chat_history = chat_history if chat_history is not None else ChatHistory() - self._id: str = thread_id or f"thread_{uuid.uuid4().hex}" - self._is_deleted = False - self.logger = logging.getLogger(__name__) - - @override - async def _create(self) -> str: - """Starts the thread and returns its ID.""" - return self._id - - @override - async def _delete(self) -> None: - """Ends the current thread.""" - self._chat_history.clear() - - @override - async def 
_on_new_message(self, new_message: str | ChatMessageContent) -> None: - """Called when a new message has been contributed to the chat.""" - if isinstance(new_message, str): - new_message = ChatMessageContent(role=AuthorRole.USER, content=new_message) - - if ( - not new_message.metadata - or "thread_id" not in new_message.metadata - or new_message.metadata["thread_id"] != self._id - ): - self._chat_history.add_message(new_message) - - async def get_messages(self) -> AsyncIterable[ChatMessageContent]: - """Retrieve the current chat history. - - Returns: - An async iterable of ChatMessageContent. - """ - if self._is_deleted: - raise AgentThreadOperationException( - "Cannot retrieve chat history, since the thread has been deleted." - ) - if self._id is None: - await self.create() - for message in self._chat_history.messages: - yield message - - async def reduce(self) -> ChatHistory | None: - """Reduce the chat history to a smaller size.""" - if self._id is None: - raise AgentThreadOperationException( - "Cannot reduce chat history, since the thread is not currently active." - ) - if not isinstance(self._chat_history, ChatHistoryReducer): - return None - return await self._chat_history.reduce() - - -class ProxyAgentResponseItem: - """Response item wrapper for proxy agent responses.""" - - def __init__(self, message: ChatMessageContent, thread: AgentThread): - self.message = message - self.thread = thread - self.logger = logging.getLogger(__name__) - - -class ProxyAgent(Agent): - """Simple proxy agent that prompts for human clarification.""" +# --------------------------------------------------------------------------- +# Internal conversation structure (minimal alternative to SK AgentThread) +# --------------------------------------------------------------------------- - # Declare as Pydantic field - user_id: str = Field( - default=None, description="User ID for WebSocket messaging" - ) +@dataclass +class ProxyConversation: + conversation_id: str = field(default_factory=lambda: f"proxy_{uuid.uuid4().hex}") + messages: List[str] = field(default_factory=list) - def __init__(self, user_id: str, **kwargs): - # Get user_id from parameter, fallback to empty string - effective_user_id = user_id or "" - super().__init__( - name="ProxyAgent", - description="Call this agent when you need to clarify requests by asking the human user for more information. 
Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.",
-            user_id=effective_user_id,
-            **kwargs,
-        )
-        self.instructions = ""
-
-    def _create_message_content(
-        self, content: str, thread_id: str = None
-    ) -> ChatMessageContent:
-        """Create a ChatMessageContent with proper metadata."""
-        return ChatMessageContent(
-            role=AuthorRole.ASSISTANT,
-            content=content,
-            name=self.name,
-            metadata={"thread_id": thread_id} if thread_id else {},
-        )
+    def add(self, content: str) -> None:
+        self.messages.append(content)

-    async def _trigger_response_callbacks(self, message_content: ChatMessageContent):
-        """Manually trigger the same response callbacks used by other agents."""
-        # Get current user_id dynamically instead of using stored value
-        current_user = self.user_id or ""

-        # Trigger the standard agent response callback
-        agent_response_callback(message_content, current_user)
+# ---------------------------------------------------------------------------
+# Proxy Agent AF
+# ---------------------------------------------------------------------------

-    async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False):
-        """Manually trigger streaming callbacks for real-time updates."""
-        # Get current user_id dynamically instead of using stored value
-        current_user = self.user_id or ""
-        streaming_message = StreamingChatMessageContent(
-            role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0
-        )
-        await streaming_agent_response_callback(
-            streaming_message, is_final, current_user
-        )
+class ProxyAgent:
+    """
+    A lightweight "agent" that mediates human clarification.
+    Not a model-backed agent; it orchestrates a request and emits a synthetic reply.
+    """

-    async def invoke(
-        self, message: str, *, thread: AgentThread | None = None, **kwargs
-    ) -> AsyncIterator[ChatMessageContent]:
-        """Ask human user for clarification about the message."""
+    def __init__(
+        self,
+        user_id: Optional[str],
+        name: str = "ProxyAgent",
+        description: str = (
+            "Clarification agent. Ask this when instructions are unclear or additional "
+            "user details are required."
+        ),
+        timeout_seconds: Optional[int] = None,
+    ):
+        self.user_id = user_id or ""
+        self.name = name
+        self.description = description
+        self._timeout = timeout_seconds or orchestration_config.default_timeout
+        self._conversation = ProxyConversation()

-        thread = await self._ensure_thread_exists_with_messages(
-            messages=message,
-            thread=thread,
-            construct_thread=lambda: DummyAgentThread(),
-            expected_type=DummyAgentThread,
-        )
+    # ---------------------------
+    # Public invocation interfaces
+    # ---------------------------

-        # Send clarification request via streaming callbacks
-        clarification_request = f"I need clarification about: {message}"
+    async def invoke(self, message: str) -> AsyncIterator[ChatResponseUpdate]:
+        """
+        One-shot style: forwards every update from invoke_stream so the final
+        clarified answer is actually yielded to the caller. Kept for parity
+        with LLM agents that expose both invocation styles.
+        """
+        async for update in self.invoke_stream(message):
+            yield update
+
+    async def invoke_stream(self, message: str) -> AsyncIterator[ChatResponseUpdate]:
+        """
+        Streaming version:
+        1. Sends clarification request via websocket (no yield yet).
+        2. Waits for human response / timeout.
+        3. 
Yields: + - A ChatResponseUpdate with the final clarified answer (as assistant text) if received. + - A usage marker (synthetic) for downstream consistency. + """ + original_prompt = message or "" + self._conversation.add(original_prompt) - clarification_message = UserClarificationRequest( - question=clarification_request, - request_id=str(uuid.uuid4()), # Unique ID for the request + clarification_req_text = f"I need clarification about: {original_prompt}" + clarification_request = UserClarificationRequest( + question=clarification_req_text, + request_id=str(uuid.uuid4()), ) - # Send the approval request to the user's WebSocket + # Dispatch websocket event requesting clarification await connection_config.send_status_update_async( { "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, - "data": clarification_message, + "data": clarification_request, }, user_id=self.user_id, message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, ) - # Get human input - human_response = await self._wait_for_user_clarification( - clarification_message.request_id - ) + # Await human clarification + human_response = await self._wait_for_user_clarification(clarification_request.request_id) - # Handle silent timeout/cancellation if human_response is None: - # Process was terminated silently - don't yield any response - logger.debug("Clarification process terminated silently - ending invoke") - return - - # Extract the answer from the response - answer = human_response.answer if human_response else "No additional clarification provided." - - response = f"Human clarification: {answer}" - - chat_message = self._create_message_content(response, thread.id) - - yield AgentResponseItem(message=chat_message, thread=thread) - - async def invoke_stream( - self, messages, thread=None, **kwargs - ) -> AsyncIterator[ProxyAgentResponseItem]: - """Stream version - handles thread management for orchestration.""" - - thread = await self._ensure_thread_exists_with_messages( - messages=messages, - thread=thread, - construct_thread=lambda: DummyAgentThread(), - expected_type=DummyAgentThread, - ) - - # Extract message content - if isinstance(messages, list) and messages: - message = ( - messages[-1].content - if hasattr(messages[-1], "content") - else str(messages[-1]) + # Timeout or cancellation already handled (timeout notification was sent). + logger.debug( + "ProxyAgentAF: No clarification response (timeout/cancel). Ending stream silently." 
) - elif isinstance(messages, str): - message = messages - else: - message = str(messages) - - # Send clarification request via streaming callbacks - clarification_request = f"I need clarification about: {message}" - - clarification_message = UserClarificationRequest( - question=clarification_request, - request_id=str(uuid.uuid4()), # Unique ID for the request - ) - - # Send the approval request to the user's WebSocket - # The user_id will be automatically retrieved from context - await connection_config.send_status_update_async( - { - "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, - "data": clarification_message, - }, - user_id=self.user_id, - message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, - ) - - # Get human input - replace with websocket call when available - human_response = await self._wait_for_user_clarification( - clarification_message.request_id - ) - - # Handle silent timeout/cancellation - if human_response is None: - # Process was terminated silently - don't yield any response - logger.debug("Clarification process terminated silently - ending invoke_stream") return - # Extract the answer from the response - answer = human_response.answer if human_response else "No additional clarification provided." + answer_text = ( + human_response.answer + if human_response.answer + else "No additional clarification provided." + ) + synthetic_reply = f"Human clarification: {answer_text}" + self._conversation.add(synthetic_reply) - response = f"Human clarification: {answer}" + # Yield final assistant text chunk + yield self._make_text_update(synthetic_reply, is_final=False) - chat_message = self._create_message_content(response, thread.id) + # Yield a synthetic usage update so downstream consumers can treat this like a model run + yield self._make_usage_update(token_estimate=len(synthetic_reply.split())) - yield AgentResponseItem(message=chat_message, thread=thread) + # --------------------------- + # Internal helpers + # --------------------------- async def _wait_for_user_clarification( self, request_id: str ) -> Optional[UserClarificationResponse]: """ - Wait for user clarification response using event-driven pattern with timeout handling. - - Args: - request_id: The request ID to wait for clarification - - Returns: - UserClarificationResponse: Clarification result with request ID and answer - - Raises: - asyncio.TimeoutError: If timeout is exceeded (300 seconds default) + Wraps orchestration_config.wait_for_clarification with robust timeout & cleanup. 
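+
+        A minimal sketch of the expected contract (names as defined in this
+        module; the request id is illustrative):
+
+            resp = await self._wait_for_user_clarification("req-123")
+            if resp is None:
+                # timeout or cancellation; the user was already notified
+                return
+            answer = resp.answer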
""" - # logger.info(f"Waiting for user clarification for request: {request_id}") - - # Initialize clarification as pending using the new event-driven method orchestration_config.set_clarification_pending(request_id) - try: - # Wait for clarification with timeout using the new event-driven method answer = await orchestration_config.wait_for_clarification(request_id) - - # logger.info(f"Clarification received for request {request_id}: {answer}") - return UserClarificationResponse( - request_id=request_id, - answer=answer, - ) + return UserClarificationResponse(request_id=request_id, answer=answer) except asyncio.TimeoutError: - # Enhanced timeout handling - notify user via WebSocket and cleanup - logger.debug(f"Clarification timeout for request {request_id} - notifying user and terminating process") - - # Create timeout notification message - from v3.models.messages import TimeoutNotification, WebsocketMessageType - timeout_notification = TimeoutNotification( - timeout_type="clarification", - request_id=request_id, - message=f"User clarification request timed out after {orchestration_config.default_timeout} seconds. Please try again.", - timestamp=time.time(), - timeout_duration=orchestration_config.default_timeout - ) - - # Send timeout notification to user via WebSocket - try: - await connection_config.send_status_update_async( - message=timeout_notification, - user_id=self.user_id, - message_type=WebsocketMessageType.TIMEOUT_NOTIFICATION, - ) - logger.info(f"Timeout notification sent to user {self.user_id} for clarification {request_id}") - except Exception as e: - logger.error(f"Failed to send timeout notification: {e}") - - # Clean up this specific request - orchestration_config.cleanup_clarification(request_id) - - # Return None to indicate silent termination - # The timeout naturally stops this specific wait operation without affecting other tasks + await self._notify_timeout(request_id) return None - - except KeyError as e: - # Silent error handling for invalid request IDs - logger.debug(f"Request ID not found: {e} - terminating process silently") - return None - except asyncio.CancelledError: - # Handle task cancellation gracefully - logger.debug(f"Clarification request {request_id} was cancelled") + logger.debug("ProxyAgentAF: Clarification request %s cancelled", request_id) orchestration_config.cleanup_clarification(request_id) return None - - except Exception as e: - # Silent error handling for unexpected errors - logger.debug(f"Unexpected error waiting for clarification: {e} - terminating process silently") + except KeyError: + logger.debug("ProxyAgentAF: Invalid clarification request id %s", request_id) + return None + except Exception as ex: # noqa: BLE001 + logger.debug("ProxyAgentAF: Unexpected error awaiting clarification: %s", ex) orchestration_config.cleanup_clarification(request_id) return None finally: - # Ensure cleanup happens for any incomplete requests - # This provides an additional safety net for resource cleanup - if (request_id in orchestration_config.clarifications and orchestration_config.clarifications[request_id] is None): - logger.debug(f"Final cleanup for pending clarification request {request_id}") + # Safety net cleanup if still pending with no value. 
+ if ( + request_id in orchestration_config.clarifications + and orchestration_config.clarifications[request_id] is None + ): orchestration_config.cleanup_clarification(request_id) - async def get_response(self, chat_history, **kwargs): - """Get response from the agent - required by Agent base class.""" - # Extract the latest user message - latest_message = ( - chat_history.messages[-1].content if chat_history.messages else "" + async def _notify_timeout(self, request_id: str) -> None: + """Send a timeout notification to the client and clean up.""" + notice = TimeoutNotification( + timeout_type="clarification", + request_id=request_id, + message=( + f"User clarification request timed out after " + f"{self._timeout} seconds. Please retry." + ), + timestamp=time.time(), + timeout_duration=self._timeout, + ) + try: + await connection_config.send_status_update_async( + message=notice, + user_id=self.user_id, + message_type=WebsocketMessageType.TIMEOUT_NOTIFICATION, + ) + logger.info( + "ProxyAgentAF: Timeout notification sent (request_id=%s user=%s)", + request_id, + self.user_id, + ) + except Exception as ex: # noqa: BLE001 + logger.error("ProxyAgentAF: Failed to send timeout notification: %s", ex) + orchestration_config.cleanup_clarification(request_id) + + def _make_text_update( + self, + text: str, + is_final: bool, + ) -> ChatResponseUpdate: + """ + Build a ChatResponseUpdate containing assistant text. We treat each + emitted text as a 'delta'; downstream can concatenate if needed. + """ + return ChatResponseUpdate( + role=Role.ASSISTANT, + text=text, + contents=[TextContent(text=text)], + conversation_id=self._conversation.conversation_id, + message_id=str(uuid.uuid4()), + response_id=str(uuid.uuid4()), ) - # Use our invoke method to get the response - async for response in self.invoke(latest_message, **kwargs): - return response - - # Fallback if no response generated - return ChatMessageContent( - role=AuthorRole.ASSISTANT, content="No clarification provided." + def _make_usage_update(self, token_estimate: int) -> ChatResponseUpdate: + """ + Provide a synthetic usage update (assist in downstream finalization logic). + """ + usage = UsageContent( + UsageDetails( + input_token_count=0, + output_token_count=token_estimate, + total_token_count=token_estimate, + ) ) + return ChatResponseUpdate( + role=Role.ASSISTANT, + text="", + contents=[usage], + conversation_id=self._conversation.conversation_id, + message_id=str(uuid.uuid4()), + response_id=str(uuid.uuid4()), + ) + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- -async def create_proxy_agent(user_id: str = None): - """Factory function for human proxy agent.""" - return ProxyAgent(user_id=user_id) +async def create_proxy_agent(user_id: Optional[str] = None) -> ProxyAgent: + """ + Factory for ProxyAgentAF (mirrors previous create_proxy_agent interface). 
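+
+    A usage sketch (the user id and question are illustrative; assumes a
+    configured websocket connection and an active event loop):
+
+        agent = await create_proxy_agent(user_id="user-123")
+        async for update in agent.invoke_stream("Which region should we deploy to?"):
+            print(update.text)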
+ """ + return ProxyAgent(user_id=user_id) \ No newline at end of file diff --git a/src/backend/af/magentic_agents/reasoning_agent.py b/src/backend/af/magentic_agents/reasoning_agent.py index b29b002ff..63979be72 100644 --- a/src/backend/af/magentic_agents/reasoning_agent.py +++ b/src/backend/af/magentic_agents/reasoning_agent.py @@ -1,19 +1,68 @@ +import asyncio import logging +import uuid +from dataclasses import dataclass +from typing import AsyncIterator, List, Optional -from common.config.app_config import config -from semantic_kernel import Kernel -from semantic_kernel.agents import ChatCompletionAgent # pylint: disable=E0611 -from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion -from v3.magentic_agents.common.lifecycle import MCPEnabledBase -from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig -from v3.magentic_agents.reasoning_search import ReasoningSearch -from v3.config.agent_registry import agent_registry +from agent_framework.azure import AzureAIAgentClient +from agent_framework import ( + ChatMessage, + ChatOptions, + ChatResponseUpdate, + HostedMCPTool, + Role, +) +from azure.identity.aio import DefaultAzureCredential +from azure.ai.projects.aio import AIProjectClient +from azure.search.documents import SearchClient +from azure.core.credentials import AzureKeyCredential +from af.magentic_agents.models.agent_models import MCPConfig, SearchConfig +from af.config.agent_registry import agent_registry -class ReasoningAgentTemplate(MCPEnabledBase): + +logger = logging.getLogger(__name__) + + +# ----------------------------- +# Lightweight search helper +# ----------------------------- +@dataclass +class _SearchContext: + client: SearchClient + top_k: int + + def fetch(self, query: str) -> List[str]: + docs: List[str] = [] + try: + results = self.client.search( + search_text=query, + query_type="simple", + select=["content"], + top=self.top_k, + ) + for r in results: + try: + docs.append(str(r["content"])) + except Exception: # noqa: BLE001 + continue + except Exception as ex: # noqa: BLE001 + logger.debug("Search fetch error: %s", ex) + return docs + + +class ReasoningAgentTemplate: """ - SK ChatCompletionAgent with optional MCP plugin injected as a Kernel plugin. - No Azure AI Agents client is needed here. We only need a token provider for SK. + agent_framework-based reasoning agent (replaces SK ChatCompletionAgent). + Class name preserved for backward compatibility. + + Differences vs original: + - No Semantic Kernel Kernel / ChatCompletionAgent. + - Streams agent_framework ChatResponseUpdate objects. + - Optional inline RAG (search results stuffed into instructions). + - Optional MCP tool exposure via HostedMCPTool. + + If callers relied on SK's ChatMessageContent objects, add an adapter layer. 
""" def __init__( @@ -22,73 +71,176 @@ def __init__( agent_description: str, agent_instructions: str, model_deployment_name: str, - azure_openai_endpoint: str, + azure_openai_endpoint: str, # kept name for compatibility; now Azure AI Project endpoint search_config: SearchConfig | None = None, mcp_config: MCPConfig | None = None, + max_search_docs: int = 3, ) -> None: - super().__init__(mcp=mcp_config) self.agent_name = agent_name self.agent_description = agent_description - self.agent_instructions = agent_instructions - self._model_deployment_name = model_deployment_name - self._openai_endpoint = azure_openai_endpoint + self.base_instructions = agent_instructions + self.model_deployment_name = model_deployment_name + self.project_endpoint = azure_openai_endpoint # reused meaning self.search_config = search_config - self.reasoning_search: ReasoningSearch | None = None - self.logger = logging.getLogger(__name__) - - def ad_token_provider(self) -> str: - credential = config.get_azure_credentials() - token = credential.get_token(config.AZURE_COGNITIVE_SERVICES) - return token.token - - async def _after_open(self) -> None: - self.kernel = Kernel() - - # Add Azure OpenAI Chat Completion service - chat = AzureChatCompletion( - deployment_name=self._model_deployment_name, - endpoint=self._openai_endpoint, - ad_token_provider=self.ad_token_provider, + self.mcp_config = mcp_config + self.max_search_docs = max_search_docs + + # Azure + client resources + self._credential: DefaultAzureCredential | None = None + self._project_client: AIProjectClient | None = None + self._client: AzureAIAgentClient | None = None + + # Optional search + self._search_ctx: _SearchContext | None = None + + self._opened = False + + # ------------- Lifecycle ------------- + async def open(self) -> "ReasoningAgentTemplate": + if self._opened: + return self + + self._credential = DefaultAzureCredential() + self._project_client = AIProjectClient( + endpoint=self.project_endpoint, + credential=self._credential, ) - self.kernel.add_service(chat) - # Initialize search capabilities - if self.search_config: - self.reasoning_search = ReasoningSearch(self.search_config) - await self.reasoning_search.initialize(self.kernel) + # Create AzureAIAgentClient (ephemeral agent will be created on first run) + self._client = AzureAIAgentClient( + project_client=self._project_client, + agent_id=None, + agent_name=self.agent_name, + model_deployment_name=self.model_deployment_name, + ) - # Inject MCP plugin into the SK kernel if available - if self.mcp_plugin: + # Optional search setup + if self.search_config and all( + [ + self.search_config.endpoint, + self.search_config.index_name, + self.search_config.api_key, + ] + ): try: - self.kernel.add_plugin(self.mcp_plugin, plugin_name="mcp_tools") - self.logger.info("Added MCP plugin") - except Exception as ex: - self.logger.exception(f"Could not add MCP plugin to kernel: {ex}") - - self._agent = ChatCompletionAgent( - kernel=self.kernel, - name=self.agent_name, - description=self.agent_description, - instructions=self.agent_instructions, - ) + sc = SearchClient( + endpoint=self.search_config.endpoint, + index_name=self.search_config.index_name, + credential=AzureKeyCredential(self.search_config.api_key), + ) + self._search_ctx = _SearchContext(client=sc, top_k=self.max_search_docs) + logger.info( + "ReasoningAgentTemplate: search index '%s' configured.", + self.search_config.index_name, + ) + except Exception as ex: # noqa: BLE001 + logger.warning("ReasoningAgentTemplate: search initialization 
failed: %s", ex) - # Register agent with global registry for tracking and cleanup + # Registry try: agent_registry.register_agent(self) - self.logger.info(f"πŸ“ Registered agent '{self.agent_name}' with global registry") - except Exception as registry_error: - self.logger.warning(f"⚠️ Failed to register agent '{self.agent_name}' with registry: {registry_error}") + except Exception: # noqa: BLE001 + pass + + self._opened = True + return self + + async def close(self) -> None: + if not self._opened: + return + try: + if self._client: + await self._client.close() + except Exception: # noqa: BLE001 + pass + try: + if self._credential: + await self._credential.close() + except Exception: # noqa: BLE001 + pass + try: + agent_registry.unregister_agent(self) + except Exception: # noqa: BLE001 + pass + + self._client = None + self._project_client = None + self._credential = None + self._search_ctx = None + self._opened = False + + async def __aenter__(self) -> "ReasoningAgentTemplate": + return await self.open() + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + # ------------- Public Invocation ------------- + async def invoke(self, message: str) -> AsyncIterator[ChatResponseUpdate]: + """ + Mirrors old streaming interface: + Yields ChatResponseUpdate objects (instead of SK ChatMessageContent). + Consumers expecting SK types should translate here. + """ + async for update in self._invoke_stream_internal(message): + yield update + + # ------------- Internal streaming logic ------------- + async def _invoke_stream_internal(self, prompt: str) -> AsyncIterator[ChatResponseUpdate]: + if not self._opened or not self._client: + raise RuntimeError("Agent not opened. Call open().") + + # Build instructions with optional search + instructions = self.base_instructions + if self._search_ctx and prompt.strip(): + docs = await self._fetch_docs_async(prompt) + if docs: + joined = "\n\n".join(f"[Doc {i+1}] {d}" for i, d in enumerate(docs)) + instructions = ( + f"{instructions}\n\nRelevant reference documents:\n{joined}\n\n" + "Use them only if they help answer the question." + ) + + tools = [] + if self.mcp_config: + tools.append( + HostedMCPTool( + name=self.mcp_config.name, + description=self.mcp_config.description, + server_label=self.mcp_config.name.replace(" ", "_"), + ) + ) + + chat_options = ChatOptions( + model_id=self.model_deployment_name, + tools=tools if tools else None, + tool_choice="auto" if tools else "none", + temperature=0.7, + allow_multiple_tool_calls=True, + ) + + messages = [ChatMessage(role=Role.USER, text=prompt)] + + async for update in self._client.get_streaming_response( + messages=messages, + chat_options=chat_options, + instructions=instructions, + ): + yield update - async def invoke(self, message: str): - """Invoke the agent with a message.""" - if not self._agent: - raise RuntimeError("Agent not initialized. 
Call open() first.") + async def _fetch_docs_async(self, query: str) -> List[str]: + if not self._search_ctx: + return [] + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: self._search_ctx.fetch(query)) - async for response in self._agent.invoke(message): - yield response + # ------------- Convenience ------------- + @property + def client(self) -> AzureAIAgentClient | None: + return self._client -# Backward‑compatible factory +# Factory (name preserved) async def create_reasoning_agent( agent_name: str, agent_description: str, @@ -108,4 +260,4 @@ async def create_reasoning_agent( mcp_config=mcp_config, ) await agent.open() - return agent + return agent \ No newline at end of file diff --git a/src/backend/af/magentic_agents/reasoning_search.py b/src/backend/af/magentic_agents/reasoning_search.py index 7f944e7f5..8912e0209 100644 --- a/src/backend/af/magentic_agents/reasoning_search.py +++ b/src/backend/af/magentic_agents/reasoning_search.py @@ -1,93 +1,183 @@ """ -RAG search capabilities for ReasoningAgentTemplate using AzureAISearchCollection. -Based on Semantic Kernel text search patterns. +Azure AI Search integration for reasoning agents (no Semantic Kernel dependency). + +This module provides: +- ReasoningSearch: lightweight wrapper around Azure Cognitive Search (Azure AI Search) +- Async initialization and async search with executor offloading +- Clean, SK-free interface for use with agent_framework-based agents + +Design goals: +- No semantic_kernel imports +- Fast to call from other async agent components +- Graceful degradation if configuration is incomplete """ +from __future__ import annotations + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional + from azure.core.credentials import AzureKeyCredential from azure.search.documents import SearchClient -from semantic_kernel import Kernel -from semantic_kernel.functions import kernel_function -from v3.magentic_agents.models.agent_models import SearchConfig +from af.magentic_agents.models.agent_models import SearchConfig class ReasoningSearch: - """Handles Azure AI Search integration for reasoning agents.""" - - def __init__(self, search_config: SearchConfig | None = None): + """ + Handles Azure AI Search (Cognitive Search) queries for retrieval / RAG augmentation. + """ + + def __init__( + self, + search_config: Optional[SearchConfig] = None, + *, + max_executor_workers: int = 4, + ) -> None: self.search_config = search_config - self.search_client: SearchClient | None = None + self.search_client: Optional[SearchClient] = None + self._executor: Optional[ThreadPoolExecutor] = None + self._max_workers = max_executor_workers + self._initialized = False + + async def initialize(self) -> bool: + """ + Initialize the search client. Safe to call multiple times. + Returns: + bool: True if initialized, False if config missing or failed. 
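+
+        A usage sketch (the query text is illustrative; configuration comes
+        from SearchConfig.from_env()):
+
+            search = ReasoningSearch(SearchConfig.from_env())
+            if await search.initialize():
+                docs = await search.search_documents("refund policy", limit=3)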
+ """ + if self._initialized: + return True - async def initialize(self, kernel: Kernel) -> bool: - """Initialize the search collection with embeddings and add it to the kernel.""" if ( not self.search_config or not self.search_config.endpoint or not self.search_config.index_name + or not self.search_config.api_key ): - print("Search configuration not available") + # Incomplete config => treat as disabled return False try: - self.search_client = SearchClient( endpoint=self.search_config.endpoint, credential=AzureKeyCredential(self.search_config.api_key), index_name=self.search_config.index_name, ) - - # Add this class as a plugin so the agent can call search_documents - kernel.add_plugin(self, plugin_name="knowledge_search") - - print( - f"Added Azure AI Search plugin for index: {self.search_config.index_name}" + # Dedicated executor for blocking search calls + self._executor = ThreadPoolExecutor( + max_workers=self._max_workers, + thread_name_prefix="reasoning_search", ) + self._initialized = True return True - - except Exception as ex: - print(f"Could not initialize Azure AI Search: {ex}") + except Exception: + # Swallow initialization errors (callers can check is_available) + self.search_client = None + self._initialized = False return False - @kernel_function( - name="search_documents", - description="Search the knowledge base for relevant documents and information. Use this when you need to find specific information from internal documents or data.", - ) - async def search_documents(self, query: str, limit: str = "3") -> str: - """Search function that the agent can invoke to find relevant documents.""" - if not self.search_client: - return "Search service is not available." + def is_available(self) -> bool: + """Return True if search is properly initialized and usable.""" + return self._initialized and self.search_client is not None - try: - limit_int = int(limit) - search_results = [] - - results = self.search_client.search( - search_text=query, - query_type="simple", - select=["content"], - top=limit_int, - ) + async def search_documents(self, query: str, limit: int = 3) -> List[str]: + """ + Perform a simple full‑text search and return the 'content' field of matching docs. + + Args: + query: Natural language or keyword query. + limit: Max number of documents. - for result in results: - search_results.append(f"content: {result['content']}") + Returns: + List of strings (each a document content snippet). Empty if none or unavailable. + """ + if not self.is_available(): + return [] - if not search_results: - return f"No relevant documents found for query: '{query}'" + limit = max(1, min(limit, 50)) # basic safety bounds - return search_results + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor( + self._executor, + lambda: self._run_search_sync(query=query, limit=limit), + ) + except Exception: + return [] - except Exception as ex: - return f"Search failed: {str(ex)}" + async def search_raw(self, query: str, limit: int = 3): + """ + Raw search returning native SDK result iterator (materialized to list). + Provided for more advanced callers needing metadata. - def is_available(self) -> bool: - """Check if search functionality is available.""" - return self.search_client is not None + Returns: + list of raw SDK result objects (dict-like). 
+ """ + if not self.is_available(): + return [] + limit = max(1, min(limit, 50)) -# Simple factory function + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor( + self._executor, lambda: self._run_search_sync(query, limit, raw=True) + ) + except Exception: + return [] + + def _run_search_sync(self, query: str, limit: int, raw: bool = False): + """ + Internal synchronous search (executed inside ThreadPoolExecutor). + """ + if not self.search_client: + return [] if not raw else [] + + results_iter = self.search_client.search( + search_text=query, + query_type="simple", + select=["content"], + top=limit, + ) + + contents: List[str] = [] + raw_items: List = [] + for item in results_iter: + try: + if raw: + raw_items.append(item) + else: + contents.append(f"{item['content']}") + except Exception: + continue + + return raw_items if raw else contents + + async def close(self) -> None: + """ + Close internal resources (executor). Idempotent. + """ + if self._executor: + self._executor.shutdown(wait=False, cancel_futures=True) + self._executor = None + self.search_client = None + self._initialized = False + + +# Factory (keeps old name, but no 'kernel' parameter needed anymore) async def create_reasoning_search( - kernel: Kernel, search_config: SearchConfig | None + search_config: Optional[SearchConfig], ) -> ReasoningSearch: - """Create and initialize a ReasoningSearch instance.""" + """ + Factory to create and initialize a ReasoningSearch instance. + + Args: + search_config: Search configuration (may be None to produce a no-op instance) + + Returns: + Initialized ReasoningSearch (is_available() indicates readiness). + """ search = ReasoningSearch(search_config) - await search.initialize(kernel) - return search + await search.initialize() + return search \ No newline at end of file From f94c27d2fcacdad9c0ec016c6502716b9521414d Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Tue, 21 Oct 2025 10:36:48 -0400 Subject: [PATCH 7/9] Refactor Foundry agent and factory imports, improve logic Updated import paths from 'v3' to 'af' for consistency and modularity. Improved error handling, logging, and tool/resource collection logic in FoundryAgentTemplate. Cleaned up unused Bing references and clarified code interpreter and MCP tool handling. These changes enhance maintainability and agent lifecycle management. 
--- .../af/magentic_agents/foundry_agent.py | 232 +++++++++--------- .../magentic_agents/magentic_agent_factory.py | 8 +- 2 files changed, 118 insertions(+), 122 deletions(-) diff --git a/src/backend/af/magentic_agents/foundry_agent.py b/src/backend/af/magentic_agents/foundry_agent.py index 30746e847..eb81ea5dc 100644 --- a/src/backend/af/magentic_agents/foundry_agent.py +++ b/src/backend/af/magentic_agents/foundry_agent.py @@ -1,22 +1,22 @@ -"""Agent template for building foundry agents with Azure AI Search, Bing, and MCP plugins (agent_framework version).""" +"""Agent template for building Foundry agents with Azure AI Search, optional MCP tool, and Code Interpreter (agent_framework version).""" import logging from typing import List, Optional from azure.ai.agents.models import Agent, AzureAISearchTool, CodeInterpreterToolDefinition from agent_framework.azure import AzureAIAgentClient -from agent_framework import ChatMessage, Role, ChatOptions, HostedMCPTool # HostedMCPTool for MCP plugin mapping +from agent_framework import ChatMessage, Role, ChatOptions, HostedMCPTool -from v3.magentic_agents.common.lifecycle import AzureAgentBase -from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig -from v3.config.agent_registry import agent_registry +from af.magentic_agents.common.lifecycle import AzureAgentBase +from af.magentic_agents.models.agent_models import MCPConfig, SearchConfig +from af.config.agent_registry import agent_registry -# exception too broad warning +# Broad exception flag # pylint: disable=w0718 class FoundryAgentTemplate(AzureAgentBase): - """Agent that uses Azure AI Search (RAG) and optional MCP tools via agent_framework.""" + """Agent that uses Azure AI Search (RAG) and optional MCP tool via agent_framework.""" def __init__( self, @@ -26,7 +26,6 @@ def __init__( model_deployment_name: str, enable_code_interpreter: bool = False, mcp_config: MCPConfig | None = None, - # bing_config: BingConfig | None = None, search_config: SearchConfig | None = None, ) -> None: super().__init__(mcp=mcp_config) @@ -35,41 +34,35 @@ def __init__( self.agent_instructions = agent_instructions self.model_deployment_name = model_deployment_name self.enable_code_interpreter = enable_code_interpreter - # self.bing = bing_config self.mcp = mcp_config self.search = search_config + self._search_connection = None - self._bing_connection = None self.logger = logging.getLogger(__name__) - if self.model_deployment_name in ["o3", "o4-mini"]: - raise ValueError( - "The current version of Foundry agents does not support reasoning models." 
- ) + if self.model_deployment_name in {"o3", "o4-mini"}: + raise ValueError("Foundry agents do not support reasoning models in this implementation.") + # ------------------------- + # Tool construction helpers + # ------------------------- async def _make_azure_search_tool(self) -> Optional[AzureAISearchTool]: - """Create Azure AI Search tool for RAG capabilities.""" - if not all([self.client, self.search and self.search.connection_name, self.search and self.search.index_name]): - self.logger.info("Azure AI Search tool not enabled") + """Create Azure AI Search tool (RAG capability).""" + if not (self.client and self.search and self.search.connection_name and self.search.index_name): + self.logger.info("Azure AI Search tool not enabled (missing config or client).") return None try: - self._search_connection = await self.client.connections.get( - name=self.search.connection_name - ) + self._search_connection = await self.client.connections.get(name=self.search.connection_name) self.logger.info("Found Azure AI Search connection: %s", self._search_connection.id) - search_tool = AzureAISearchTool( + return AzureAISearchTool( index_connection_id=self._search_connection.id, index_name=self.search.index_name, ) - self.logger.info("Azure AI Search tool created for index: %s", self.search.index_name) - return search_tool - except Exception as ex: self.logger.error( - "Azure AI Search tool creation failed: %s | Connection name: %s | Index name: %s | " - "Ensure the connection exists in Azure AI Foundry portal.", + "Azure AI Search tool creation failed: %s | connection=%s | index=%s", ex, getattr(self.search, "connection_name", None), getattr(self.search, "index_name", None), @@ -77,43 +70,55 @@ async def _make_azure_search_tool(self) -> Optional[AzureAISearchTool]: return None async def _collect_tools_and_resources(self) -> tuple[List, dict]: - """Collect all available tools and tool_resources to embed in persistent agent definition.""" + """Collect tool definitions + tool_resources for agent definition creation.""" tools: List = [] tool_resources: dict = {} + # Search tool if self.search and self.search.connection_name and self.search.index_name: search_tool = await self._make_azure_search_tool() if search_tool: tools.extend(search_tool.definitions) tool_resources = search_tool.resources self.logger.info( - "Added Azure AI Search tools: %d tool definitions", len(search_tool.definitions) + "Added %d Azure AI Search tool definitions.", + len(search_tool.definitions), ) else: - self.logger.error("Azure AI Search tool not configured properly") + self.logger.warning("Azure AI Search tool not configured properly.") + # Code Interpreter if self.enable_code_interpreter: try: tools.append(CodeInterpreterToolDefinition()) - self.logger.info("Added Code Interpreter tool") + self.logger.info("Added Code Interpreter tool definition.") except ImportError as ie: - self.logger.error("Code Interpreter tool requires additional dependencies: %s", ie) + self.logger.error("Code Interpreter dependency missing: %s", ie) - self.logger.info("Total tools configured in definition: %d", len(tools)) + self.logger.info("Total tool definitions collected: %d", len(tools)) return tools, tool_resources + # ------------------------- + # Agent lifecycle override + # ------------------------- async def _after_open(self) -> None: - """Build or reuse the Azure AI agent definition; create agent_framework client.""" + """Create or reuse Azure AI agent definition and wrap with AzureAIAgentClient.""" definition = await 
self._get_azure_ai_agent_definition(self.agent_name) if definition is not None: - connection_compatible = await self._check_connection_compatibility(definition) - if not connection_compatible: - await self.client.agents.delete_agent(definition.id) - self.logger.info( - "Existing agent '%s' used incompatible connection. Creating new definition.", - self.agent_name, - ) + if not await self._check_connection_compatibility(definition): + try: + await self.client.agents.delete_agent(definition.id) + self.logger.info( + "Deleted incompatible existing agent '%s'; will recreate with new connection settings.", + self.agent_name, + ) + except Exception as ex: + self.logger.warning( + "Failed deleting incompatible agent '%s': %s (will still recreate).", + self.agent_name, + ex, + ) definition = None if definition is None: @@ -126,138 +131,128 @@ async def _after_open(self) -> None: tools=tools, tool_resources=tool_resources, ) + self.logger.info("Created new Azure AI agent definition '%s'", self.agent_name) + # Instantiate persistent AzureAIAgentClient bound to existing agent_id try: - # Wrap existing agent definition with agent_framework client (persistent agent mode) self._agent = AzureAIAgentClient( project_client=self.client, agent_id=str(definition.id), agent_name=self.agent_name, - thread_id=None, # created dynamically if omitted during invocation ) except Exception as ex: self.logger.error("Failed to initialize AzureAIAgentClient: %s", ex) raise - # Register with global registry + # Register agent globally try: agent_registry.register_agent(self) - self.logger.info("πŸ“ Registered agent '%s' with global registry", self.agent_name) - except Exception as registry_error: - self.logger.warning( - "⚠️ Failed to register agent '%s' with registry: %s", self.agent_name, registry_error - ) - - async def fetch_run_details(self, thread_id: str, run_id: str) -> None: - """Fetch and log run details on failure for diagnostics.""" - try: - run = await self.client.agents.runs.get(thread=thread_id, run=run_id) - self.logger.error( - "Run failure details | status=%s | id=%s | last_error=%s | usage=%s", - getattr(run, "status", None), - run_id, - getattr(run, "last_error", None), - getattr(run, "usage", None), - ) - except Exception as ex: - self.logger.error("Could not fetch run details: %s", ex) + self.logger.info("Registered agent '%s' in global registry.", self.agent_name) + except Exception as reg_ex: + self.logger.warning("Could not register agent '%s': %s", self.agent_name, reg_ex) + # ------------------------- + # Definition compatibility + # ------------------------- async def _check_connection_compatibility(self, existing_definition: Agent) -> bool: - """Ensure existing agent definition's Azure AI Search connection matches current configuration.""" + """Verify existing Azure AI Search connection matches current config.""" try: - if not self.search or not self.search.connection_name: - self.logger.info("No search configuration provided; treating existing definition as compatible.") + if not (self.search and self.search.connection_name): + self.logger.info("No search config provided; assuming compatibility.") return True - if not getattr(existing_definition, "tool_resources", None): - self.logger.info("Existing definition lacks tool resources.") - return not self.search.connection_name - - azure_ai_search_resources = existing_definition.tool_resources.get("azure_ai_search", {}) - if not azure_ai_search_resources: - self.logger.info("Existing definition has no Azure AI Search resources.") + 
tool_resources = getattr(existing_definition, "tool_resources", None) + if not tool_resources: + self.logger.info("Existing agent has no tool resources; incompatible with search requirement.") return False - indexes = azure_ai_search_resources.get("indexes", []) + azure_search = tool_resources.get("azure_ai_search", {}) + indexes = azure_search.get("indexes", []) if not indexes: - self.logger.info("Existing definition search resources contain no indexes.") + self.logger.info("Existing agent has no Azure AI Search indexes; incompatible.") return False - existing_connection_id = indexes[0].get("index_connection_id") - if not existing_connection_id: - self.logger.info("Existing definition missing connection ID.") + existing_conn_id = indexes[0].get("index_connection_id") + if not existing_conn_id: + self.logger.info("Existing agent missing index_connection_id; incompatible.") return False current_connection = await self.client.connections.get(name=self.search.connection_name) - current_connection_id = current_connection.id - compatible = existing_connection_id == current_connection_id - - if compatible: - self.logger.info("Connection compatible: %s", existing_connection_id) + same = existing_conn_id == current_connection.id + if same: + self.logger.info("Search connection compatible: %s", existing_conn_id) else: self.logger.info( - "Connection mismatch: existing %s vs current %s", - existing_connection_id, - current_connection_id, + "Search connection mismatch: existing=%s current=%s", + existing_conn_id, + current_connection.id, ) - return compatible + return same except Exception as ex: - self.logger.error("Error checking connection compatibility: %s", ex) + self.logger.error("Error during connection compatibility check: %s", ex) return False async def _get_azure_ai_agent_definition(self, agent_name: str) -> Agent | None: - """Retrieve an existing Azure AI Agent definition by name if present.""" + """Return existing agent definition by name or None.""" try: - agent_id = None - agent_list = self.client.agents.list_agents() - async for agent in agent_list: + async for agent in self.client.agents.list_agents(): if agent.name == agent_name: - agent_id = agent.id - break - if agent_id is not None: - self.logger.info("Found existing agent definition with ID %s", agent_id) - return await self.client.agents.get_agent(agent_id) + self.logger.info("Found existing agent '%s' (id=%s).", agent_name, agent.id) + return await self.client.agents.get_agent(agent.id) return None except Exception as e: if "ResourceNotFound" in str(e) or "404" in str(e): - self.logger.info("Agent '%s' not found; will create new definition.", agent_name) + self.logger.info("Agent '%s' not found; will create new.", agent_name) else: self.logger.warning( - "Unexpected error retrieving agent '%s': %s. 
Proceeding to create new definition.", + "Unexpected error listing agent '%s': %s; will attempt creation.", agent_name, e, ) return None + # ------------------------- + # Diagnostics helper + # ------------------------- + async def fetch_run_details(self, thread_id: str, run_id: str) -> None: + """Log run diagnostics for a failed run.""" + try: + run = await self.client.agents.runs.get(thread=thread_id, run=run_id) + self.logger.error( + "Run failure | status=%s | id=%s | last_error=%s | usage=%s", + getattr(run, "status", None), + run_id, + getattr(run, "last_error", None), + getattr(run, "usage", None), + ) + except Exception as ex: + self.logger.error("Failed fetching run details (thread=%s run=%s): %s", thread_id, run_id, ex) + + # ------------------------- + # Invocation (streaming) + # ------------------------- async def invoke(self, prompt: str): """ Stream model output for a prompt. - Yields agent_framework ChatResponseUpdate objects: - - update.text for incremental text - - update.contents for tool calls / usage events + Yields ChatResponseUpdate objects: + - update.text for incremental text + - update.contents for tool calls / usage events """ - if not hasattr(self, "_agent") or self._agent is None: + if not self._agent: raise RuntimeError("Agent client not initialized; call open() first.") messages = [ChatMessage(role=Role.USER, text=prompt)] tools = [] - # Map MCP plugin (if any) to HostedMCPTool for runtime tool calling - if self.mcp_plugin: - # Minimal HostedMCPTool; advanced mapping (approval modes, categories) can be added later. - tools.append( - HostedMCPTool( - name=self.mcp_plugin.name, - server_label=self.mcp_plugin.name.replace(" ", "_"), - description=getattr(self.mcp_plugin, "description", ""), - ) - ) + # Use mcp_tool prepared in AzureAgentBase + if self.mcp_tool and isinstance(self.mcp_tool, HostedMCPTool): + tools.append(self.mcp_tool) chat_options = ChatOptions( model_id=self.model_deployment_name, tools=tools if tools else None, - tool_choice="auto", + tool_choice="auto" if tools else "none", allow_multiple_tool_calls=True, temperature=0.7, ) @@ -270,16 +265,18 @@ async def invoke(self, prompt: str): yield update +# ------------------------- +# Factory +# ------------------------- async def create_foundry_agent( agent_name: str, agent_description: str, agent_instructions: str, model_deployment_name: str, - mcp_config: MCPConfig, - # bing_config: BingConfig, - search_config: SearchConfig, + mcp_config: MCPConfig | None, + search_config: SearchConfig | None, ) -> FoundryAgentTemplate: - """Factory function to create and open a FoundryAgentTemplate (agent_framework version).""" + """Factory to create and open a FoundryAgentTemplate (agent_framework version).""" agent = FoundryAgentTemplate( agent_name=agent_name, agent_description=agent_description, @@ -287,7 +284,6 @@ async def create_foundry_agent( model_deployment_name=model_deployment_name, enable_code_interpreter=True, mcp_config=mcp_config, - # bing_config=bing_config, search_config=search_config, ) await agent.open() diff --git a/src/backend/af/magentic_agents/magentic_agent_factory.py b/src/backend/af/magentic_agents/magentic_agent_factory.py index ed74e89be..5ec3b57ce 100644 --- a/src/backend/af/magentic_agents/magentic_agent_factory.py +++ b/src/backend/af/magentic_agents/magentic_agent_factory.py @@ -8,13 +8,13 @@ from common.config.app_config import config from common.models.messages_kernel import TeamConfiguration -from v3.magentic_agents.foundry_agent import FoundryAgentTemplate -from 
v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig +from af.magentic_agents.foundry_agent import FoundryAgentTemplate +from af.magentic_agents.models.agent_models import MCPConfig, SearchConfig # from v3.magentic_agents.models.agent_models import (BingConfig, MCPConfig, # SearchConfig) -from v3.magentic_agents.proxy_agent import ProxyAgent -from v3.magentic_agents.reasoning_agent import ReasoningAgentTemplate +from af.magentic_agents.proxy_agent import ProxyAgent +from af.magentic_agents.reasoning_agent import ReasoningAgentTemplate class UnsupportedModelError(Exception): From 899e7d633e06575d80933bb0844b39bcb3e2fd39 Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Tue, 21 Oct 2025 10:46:56 -0400 Subject: [PATCH 8/9] Update AzureAIAgentClient import paths Changed imports from 'agent_framework.azure' to 'agent_framework_azure_ai' for AzureAIAgentClient in foundry_agent.py and reasoning_agent.py. Also removed an unused import in lifecycle.py to clean up dependencies. --- src/backend/af/magentic_agents/common/lifecycle.py | 1 - src/backend/af/magentic_agents/foundry_agent.py | 2 +- src/backend/af/magentic_agents/reasoning_agent.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/backend/af/magentic_agents/common/lifecycle.py b/src/backend/af/magentic_agents/common/lifecycle.py index 17d14676e..9f8c6fd0b 100644 --- a/src/backend/af/magentic_agents/common/lifecycle.py +++ b/src/backend/af/magentic_agents/common/lifecycle.py @@ -7,7 +7,6 @@ from azure.ai.projects.aio import AIProjectClient from azure.identity.aio import DefaultAzureCredential -from agent_framework.azure import AzureAIAgentClient from agent_framework import HostedMCPTool from af.magentic_agents.models.agent_models import MCPConfig diff --git a/src/backend/af/magentic_agents/foundry_agent.py b/src/backend/af/magentic_agents/foundry_agent.py index eb81ea5dc..30821a891 100644 --- a/src/backend/af/magentic_agents/foundry_agent.py +++ b/src/backend/af/magentic_agents/foundry_agent.py @@ -4,7 +4,7 @@ from typing import List, Optional from azure.ai.agents.models import Agent, AzureAISearchTool, CodeInterpreterToolDefinition -from agent_framework.azure import AzureAIAgentClient +from agent_framework_azure_ai import AzureAIAgentClient from agent_framework import ChatMessage, Role, ChatOptions, HostedMCPTool from af.magentic_agents.common.lifecycle import AzureAgentBase diff --git a/src/backend/af/magentic_agents/reasoning_agent.py b/src/backend/af/magentic_agents/reasoning_agent.py index 63979be72..52e8365a1 100644 --- a/src/backend/af/magentic_agents/reasoning_agent.py +++ b/src/backend/af/magentic_agents/reasoning_agent.py @@ -3,8 +3,7 @@ import uuid from dataclasses import dataclass from typing import AsyncIterator, List, Optional - -from agent_framework.azure import AzureAIAgentClient +from agent_framework_azure_ai import AzureAIAgentClient from agent_framework import ( ChatMessage, ChatOptions, From 1b5204463c5710c8168d749e617b80dca185645d Mon Sep 17 00:00:00 2001 From: Francia Riesco Date: Tue, 21 Oct 2025 11:57:04 -0400 Subject: [PATCH 9/9] Refactor to use agent_framework for orchestration Migrates orchestration logic and related services from semantic-kernel (v3) to agent_framework, updating imports, manager classes, and workflow construction. Updates HumanApprovalMagenticManager and OrchestrationManager to use agent_framework APIs, adapts callback handling, and ensures compatibility with new message and agent structures. 
Cleans up legacy code and improves maintainability for future agent_framework-based enhancements. --- src/backend/af/api/router.py | 4 +- .../af/common/services/agents_service.py | 2 +- .../af/common/services/plan_service.py | 6 +- .../af/common/services/team_service.py | 4 +- src/backend/af/config/settings.py | 4 +- .../helper/plan_to_mplan_converter.py | 2 +- .../orchestration/human_approval_manager.py | 236 ++++--------- .../af/orchestration/orchestration_manager.py | 333 +++++++++++------- 8 files changed, 279 insertions(+), 312 deletions(-) diff --git a/src/backend/af/api/router.py b/src/backend/af/api/router.py index bf654444d..83c15edcf 100644 --- a/src/backend/af/api/router.py +++ b/src/backend/af/api/router.py @@ -7,14 +7,14 @@ import af.models.messages as messages from auth.auth_utils import get_authenticated_user_details from common.database.database_factory import DatabaseFactory -from common.models.messages_kernel import ( +from common.models.messages_af import ( InputTask, Plan, PlanStatus, TeamSelectionRequest, ) from common.utils.event_utils import track_event_if_configured -from common.utils.utils_kernel import rai_success, rai_validate_team_config +from common.utils.utils_af import rai_success, rai_validate_team_config from fastapi import ( APIRouter, BackgroundTasks, diff --git a/src/backend/af/common/services/agents_service.py b/src/backend/af/common/services/agents_service.py index fc4e7fa06..e1cc49268 100644 --- a/src/backend/af/common/services/agents_service.py +++ b/src/backend/af/common/services/agents_service.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Union from common.models.messages_kernel import TeamAgent, TeamConfiguration -from v3.common.services.team_service import TeamService +from af.common.services.team_service import TeamService class AgentsService: diff --git a/src/backend/af/common/services/plan_service.py b/src/backend/af/common/services/plan_service.py index ff7d5b30e..3edb3e3a2 100644 --- a/src/backend/af/common/services/plan_service.py +++ b/src/backend/af/common/services/plan_service.py @@ -2,16 +2,16 @@ import logging from dataclasses import asdict -import v3.models.messages as messages +import af.models.messages as messages from common.database.database_factory import DatabaseFactory -from common.models.messages_kernel import ( +from common.models.messages_af import ( AgentMessageData, AgentMessageType, AgentType, PlanStatus, ) from common.utils.event_utils import track_event_if_configured -from v3.config.settings import orchestration_config +from af.config.settings import orchestration_config logger = logging.getLogger(__name__) diff --git a/src/backend/af/common/services/team_service.py b/src/backend/af/common/services/team_service.py index 02b9cdc2a..e5fdbd65f 100644 --- a/src/backend/af/common/services/team_service.py +++ b/src/backend/af/common/services/team_service.py @@ -11,13 +11,13 @@ from azure.search.documents.indexes import SearchIndexClient from common.config.app_config import config from common.database.database_base import DatabaseBase -from common.models.messages_kernel import ( +from common.models.messages_af import ( StartingTask, TeamAgent, TeamConfiguration, UserCurrentTeam, ) -from v3.common.services.foundry_service import FoundryService +from af.common.services.foundry_service import FoundryService class TeamService: diff --git a/src/backend/af/config/settings.py b/src/backend/af/config/settings.py index 5958eb4f3..0770dc3bf 100644 --- a/src/backend/af/config/settings.py +++ b/src/backend/af/config/settings.py 
@@ -9,14 +9,14 @@ from typing import Dict, Optional from common.config.app_config import config -from common.models.messages_kernel import TeamConfiguration +from common.models.messages_af import TeamConfiguration from fastapi import WebSocket from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration from semantic_kernel.connectors.ai.open_ai import ( AzureChatCompletion, OpenAIChatPromptExecutionSettings, ) -from v3.models.messages import MPlan, WebsocketMessageType +from af.models.messages import MPlan, WebsocketMessageType logger = logging.getLogger(__name__) diff --git a/src/backend/af/orchestration/helper/plan_to_mplan_converter.py b/src/backend/af/orchestration/helper/plan_to_mplan_converter.py index bc1dd5346..1c048b1c3 100644 --- a/src/backend/af/orchestration/helper/plan_to_mplan_converter.py +++ b/src/backend/af/orchestration/helper/plan_to_mplan_converter.py @@ -2,7 +2,7 @@ import re from typing import Iterable, List, Optional -from v3.models.models import MPlan, MStep +from af.models.models import MPlan, MStep logger = logging.getLogger(__name__) diff --git a/src/backend/af/orchestration/human_approval_manager.py b/src/backend/af/orchestration/human_approval_manager.py index bfba4befe..e7824c486 100644 --- a/src/backend/af/orchestration/human_approval_manager.py +++ b/src/backend/af/orchestration/human_approval_manager.py @@ -1,45 +1,40 @@ """ Human-in-the-loop Magentic Manager for employee onboarding orchestration. -Extends StandardMagenticManager to add approval gates before plan execution. +Extends StandardMagenticManager (agent_framework version) to add approval gates before plan execution. """ import asyncio import logging from typing import Any, Optional -import v3.models.messages as messages -from semantic_kernel.agents.orchestration.magentic import ( +import af.models.messages as messages +from agent_framework import ChatMessage, Role +from agent_framework._workflows._magentic import ( MagenticContext, - ProgressLedger, - ProgressLedgerItem, + MagenticProgressLedger as ProgressLedger, + MagenticProgressLedgerItem as ProgressLedgerItem, StandardMagenticManager, -) -from semantic_kernel.agents.orchestration.prompts._magentic_prompts import ( ORCHESTRATOR_FINAL_ANSWER_PROMPT, ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT, ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT, ) -from semantic_kernel.contents import ChatMessageContent + from v3.config.settings import connection_config, orchestration_config from v3.models.models import MPlan -from v3.orchestration.helper.plan_to_mplan_converter import \ - PlanToMPlanConverter +from v3.orchestration.helper.plan_to_mplan_converter import PlanToMPlanConverter -# Using a module level logger to avoid pydantic issues around inherited fields logger = logging.getLogger(__name__) -# Create a progress ledger that indicates the request is satisfied (task completed) class HumanApprovalMagenticManager(StandardMagenticManager): """ - Extended Magentic manager that requires human approval before executing plan steps. + Extended Magentic manager (agent_framework) that requires human approval before executing plan steps. Provides interactive approval for each step in the orchestration plan. 
""" - # Define Pydantic fields to avoid validation errors approval_enabled: bool = True magentic_plan: Optional[MPlan] = None - current_user_id: str + current_user_id: str # populated in __init__ def __init__(self, user_id: str, *args, **kwargs): """ @@ -50,8 +45,6 @@ def __init__(self, user_id: str, *args, **kwargs): **kwargs: Additional keyword arguments for the parent StandardMagenticManager. """ - # Remove any custom kwargs before passing to parent - plan_append = """ IMPORTANT: Never ask the user for information or clarification until all agents on the team have been asked first. @@ -69,51 +62,43 @@ def __init__(self, user_id: str, *args, **kwargs): - **DocumentCreationAgent** to draft a comprehensive onboarding plan that includes a checklist of resources and materials needed for effective onboarding. - **ProxyAgent** to review the drafted onboarding plan for clarity and completeness. - **MagenticManager** to finalize the onboarding plan and prepare it for presentation to stakeholders. - """ final_append = """ - DO NOT EVER OFFER TO HELP FURTHER IN THE FINAL ANSWER! Just provide the final answer and end with a polite closing. +DO NOT EVER OFFER TO HELP FURTHER IN THE FINAL ANSWER! Just provide the final answer and end with a polite closing. """ - # kwargs["task_ledger_facts_prompt"] = ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT + facts_append - kwargs["task_ledger_plan_prompt"] = ( - ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT + plan_append - ) - kwargs["task_ledger_plan_update_prompt"] = ( - ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append - ) + kwargs["task_ledger_plan_prompt"] = ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT + plan_append + kwargs["task_ledger_plan_update_prompt"] = ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append + kwargs["current_user_id"] = user_id # retained for downstream usage if needed - kwargs['current_user_id'] = user_id - + self.current_user_id = user_id super().__init__(*args, **kwargs) async def plan(self, magentic_context: MagenticContext) -> Any: """ Override the plan method to create the plan first, then ask for approval before execution. + Returns the original plan ChatMessage if approved, otherwise raises. 
""" - # Extract task text from the context - task_text = magentic_context.task - if hasattr(task_text, "content"): - task_text = task_text.content - elif not isinstance(task_text, str): - task_text = str(task_text) + # Normalize task text + task_text = getattr(magentic_context.task, "text", str(magentic_context.task)) logger.info("\n Human-in-the-Loop Magentic Manager Creating Plan:") logger.info(" Task: %s", task_text) logger.info("-" * 60) - # First, let the parent create the actual plan logger.info(" Creating execution plan...") - plan = await super().plan(magentic_context) - logger.info(" Plan created: %s", plan) + plan_message = await super().plan(magentic_context) + logger.info(" Plan created (assistant message length=%d)", len(plan_message.text) if plan_message and plan_message.text else 0) - self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger) + # Build structured MPlan from task ledger + if self.task_ledger is None: + raise RuntimeError("task_ledger not set after plan()") - self.magentic_plan.user_id = self.current_user_id + self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger) + self.magentic_plan.user_id = self.current_user_id # annotate with user - # Request approval from the user before executing the plan approval_message = messages.PlanApprovalRequest( plan=self.magentic_plan, status="PENDING_APPROVAL", @@ -126,25 +111,25 @@ async def plan(self, magentic_context: MagenticContext) -> Any: else {} ), ) + try: orchestration_config.plans[self.magentic_plan.id] = self.magentic_plan - except Exception as e: + except Exception as e: # noqa: BLE001 logger.error("Error processing plan approval: %s", e) - # Send the approval request to the user's WebSocket - # The user_id will be automatically retrieved from context + # Send approval request await connection_config.send_status_update_async( message=approval_message, user_id=self.current_user_id, message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST, ) - # Wait for user approval + # Await user response approval_response = await self._wait_for_user_approval(approval_message.plan.id) if approval_response and approval_response.approved: logger.info("Plan approved - proceeding with execution...") - return plan + return plan_message else: logger.debug("Plan execution cancelled by user") await connection_config.send_status_update_async( @@ -161,18 +146,16 @@ async def replan(self, magentic_context: MagenticContext) -> Any: """ Override to add websocket messages for replanning events. """ - logger.info("\nHuman-in-the-Loop Magentic Manager replanned:") - replan = await super().replan(magentic_context=magentic_context) - logger.info("Replanned: %s", replan) - return replan - - async def create_progress_ledger( - self, magentic_context: MagenticContext - ) -> ProgressLedger: - """Check for max rounds exceeded and send final message if so.""" + replan_message = await super().replan(magentic_context=magentic_context) + logger.info("Replanned message length: %d", len(replan_message.text) if replan_message and replan_message.text else 0) + return replan_message + + async def create_progress_ledger(self, magentic_context: MagenticContext) -> ProgressLedger: + """ + Check for max rounds exceeded and send final message if so, else defer to base. 
+ """ if magentic_context.round_count >= orchestration_config.max_rounds: - # Send final message to user final_message = messages.FinalResultMessage( content="Process terminated: Maximum rounds exceeded", status="terminated", @@ -186,13 +169,9 @@ async def create_progress_ledger( ) return ProgressLedger( - is_request_satisfied=ProgressLedgerItem( - reason="Maximum rounds exceeded", answer=True - ), + is_request_satisfied=ProgressLedgerItem(reason="Maximum rounds exceeded", answer=True), is_in_loop=ProgressLedgerItem(reason="Terminating", answer=False), - is_progress_being_made=ProgressLedgerItem( - reason="Terminating", answer=False - ), + is_progress_being_made=ProgressLedgerItem(reason="Terminating", answer=False), next_speaker=ProgressLedgerItem(reason="Task complete", answer=""), instruction_or_question=ProgressLedgerItem( reason="Task complete", @@ -200,171 +179,90 @@ async def create_progress_ledger( ), ) + # Delegate to base (which creates a MagenticProgressLedger) return await super().create_progress_ledger(magentic_context) - # plan_id will not be optional in future async def _wait_for_user_approval( self, m_plan_id: Optional[str] = None ) -> Optional[messages.PlanApprovalResponse]: """ Wait for user approval response using event-driven pattern with timeout handling. - - Args: - m_plan_id: The plan ID to wait for approval - - Returns: - PlanApprovalResponse: Approval result with approved status and plan ID - - Raises: - asyncio.TimeoutError: If timeout is exceeded (300 seconds default) """ - logger.info(f"Waiting for user approval for plan: {m_plan_id}") + logger.info("Waiting for user approval for plan: %s", m_plan_id) if not m_plan_id: logger.error("No plan ID provided for approval") return messages.PlanApprovalResponse(approved=False, m_plan_id=m_plan_id) - # Initialize approval as pending using the new event-driven method orchestration_config.set_approval_pending(m_plan_id) try: - # Wait for approval with timeout using the new event-driven method approved = await orchestration_config.wait_for_approval(m_plan_id) + logger.info("Approval received for plan %s: %s", m_plan_id, approved) + return messages.PlanApprovalResponse(approved=approved, m_plan_id=m_plan_id) - logger.info(f"Approval received for plan {m_plan_id}: {approved}") - return messages.PlanApprovalResponse( - approved=approved, m_plan_id=m_plan_id - ) except asyncio.TimeoutError: - # Enhanced timeout handling - notify user via WebSocket and cleanup - logger.debug(f"Approval timeout for plan {m_plan_id} - notifying user and terminating process") + logger.debug("Approval timeout for plan %s - notifying user and terminating process", m_plan_id) - # Create timeout notification message timeout_message = messages.TimeoutNotification( timeout_type="approval", request_id=m_plan_id, message=f"Plan approval request timed out after {orchestration_config.default_timeout} seconds. 
Please try again.", timestamp=asyncio.get_event_loop().time(), - timeout_duration=orchestration_config.default_timeout + timeout_duration=orchestration_config.default_timeout, ) - # Send timeout notification to user via WebSocket try: await connection_config.send_status_update_async( message=timeout_message, user_id=self.current_user_id, message_type=messages.WebsocketMessageType.TIMEOUT_NOTIFICATION, ) - logger.info(f"Timeout notification sent to user {self.current_user_id} for plan {m_plan_id}") - except Exception as e: - logger.error(f"Failed to send timeout notification: {e}") + logger.info("Timeout notification sent to user %s for plan %s", self.current_user_id, m_plan_id) + except Exception as e: # noqa: BLE001 + logger.error("Failed to send timeout notification: %s", e) - # Clean up this specific request orchestration_config.cleanup_approval(m_plan_id) - - # Return None to indicate silent termination - # The timeout naturally stops this specific wait operation without affecting other tasks return None - except KeyError as e: - # Silent error handling for invalid plan IDs - logger.debug(f"Plan ID not found: {e} - terminating process silently") + except KeyError as e: # noqa: BLE001 + logger.debug("Plan ID not found: %s - terminating process silently", e) return None except asyncio.CancelledError: - # Handle task cancellation gracefully - logger.debug(f"Approval request {m_plan_id} was cancelled") + logger.debug("Approval request %s was cancelled", m_plan_id) orchestration_config.cleanup_approval(m_plan_id) return None - except Exception as e: - # Silent error handling for unexpected errors - logger.debug(f"Unexpected error waiting for approval: {e} - terminating process silently") + except Exception as e: # noqa: BLE001 + logger.debug("Unexpected error waiting for approval: %s - terminating process silently", e) orchestration_config.cleanup_approval(m_plan_id) return None + finally: - # Ensure cleanup happens for any incomplete requests - # This provides an additional safety net for resource cleanup - if (m_plan_id in orchestration_config.approvals and orchestration_config.approvals[m_plan_id] is None): - logger.debug(f"Final cleanup for pending approval plan {m_plan_id}") + if m_plan_id in orchestration_config.approvals and orchestration_config.approvals[m_plan_id] is None: + logger.debug("Final cleanup for pending approval plan %s", m_plan_id) orchestration_config.cleanup_approval(m_plan_id) - async def prepare_final_answer( - self, magentic_context: MagenticContext - ) -> ChatMessageContent: + async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: """ Override to ensure final answer is prepared after all steps are executed. 
""" logger.info("\n Magentic Manager - Preparing final answer...") - return await super().prepare_final_answer(magentic_context) - def plan_to_obj(self, magentic_context, ledger) -> MPlan: + def plan_to_obj(self, magentic_context: MagenticContext, ledger) -> MPlan: """Convert the generated plan from the ledger into a structured MPlan object.""" + if ledger is None or not hasattr(ledger, "plan") or not hasattr(ledger, "facts"): + raise ValueError("Invalid ledger structure; expected plan and facts attributes.") + + task_text = getattr(magentic_context.task, "text", str(magentic_context.task)) return_plan: MPlan = PlanToMPlanConverter.convert( - plan_text=ledger.plan.content, - facts=ledger.facts.content, + plan_text=getattr(ledger.plan, "text", ""), + facts=getattr(ledger.facts, "text", ""), team=list(magentic_context.participant_descriptions.keys()), - task=magentic_context.task, + task=task_text, ) - # # get the request text from the ledger - # if hasattr(magentic_context, 'task'): - # return_plan.user_request = magentic_context.task - - # return_plan.team = list(magentic_context.participant_descriptions.keys()) - - # # Get the facts content from the ledger - # if hasattr(ledger, 'facts') and ledger.facts.content: - # return_plan.facts = ledger.facts.content - - # # Get the plan / steps content from the ledger - # # Split the description into lines and clean them - # lines = [line.strip() for line in ledger.plan.content.strip().split('\n') if line.strip()] - - # found_agent = None - # prefix = None - - # for line in lines: - # found_agent = None - # prefix = None - # # log the line for troubleshooting - # logger.debug("Processing plan line: %s", line) - - # # match only lines that have bullet points - # if re.match(r'^[-β€’*]\s+', line): - # # Remove the bullet point marker - # line = re.sub(r'^[-β€’*]\s+', '', line).strip() - - # # Look for agent names in the line - - # for agent_name in return_plan.team: - # # Check if agent name appears in the line (case insensitive) - # if agent_name.lower() in line[:20].lower(): - # found_agent = agent_name - # line = line.split(agent_name, 1) - # line = line[1].strip() if len(line) > 1 else "" - # line = line.replace('*', '').strip() - # break - - # if not found_agent: - # # If no agent found, assign to ProxyAgent if available - # found_agent = "MagenticAgent" - # # If line indicates a following list of actions (e.g. "Assign **EnhancedResearchAgent** - # # to gather authoritative data on:") save and prefix to the steps - # # if line.endswith(':'): - # # line = line.replace(':', '').strip() - # # prefix = line + " " - - # # Don't create a step if action is blank - # if line.strip() != "": - # if prefix: - # line = prefix + line - # # Create the step object - # step = MStep(agent=found_agent, action=line) - - # # add the step to the plan - # return_plan.steps.append(step) # pylint: disable=E1101 - - return return_plan + return return_plan \ No newline at end of file diff --git a/src/backend/af/orchestration/orchestration_manager.py b/src/backend/af/orchestration/orchestration_manager.py index 7db458fee..c6c46e146 100644 --- a/src/backend/af/orchestration/orchestration_manager.py +++ b/src/backend/af/orchestration/orchestration_manager.py @@ -1,49 +1,95 @@ -# Copyright (c) Microsoft. All rights reserved. 
-"""Orchestration manager to handle the orchestration logic.""" +"""Orchestration manager (agent_framework version) handling multi-agent Magentic workflow creation and execution.""" + import asyncio import logging import uuid -from typing import List, Optional +from typing import List, Optional, Callable, Awaitable from common.config.app_config import config from common.models.messages_kernel import TeamConfiguration -from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration -from semantic_kernel.agents.runtime import InProcessRuntime - -# Create custom execution settings to fix schema issues -from semantic_kernel.connectors.ai.open_ai import ( - AzureChatCompletion, OpenAIChatPromptExecutionSettings) -from semantic_kernel.contents import (ChatMessageContent, - StreamingChatMessageContent) -from v3.callbacks.response_handlers import (agent_response_callback, - streaming_agent_response_callback) -from v3.config.settings import connection_config, orchestration_config -from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory -from v3.models.messages import WebsocketMessageType -from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager + +# agent_framework imports +from agent_framework import ChatMessage, Role, ChatOptions +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework._workflows import ( + MagenticBuilder, + MagenticCallbackMode, +) +from agent_framework._workflows._magentic import AgentRunResponseUpdate # type: ignore + +# Existing (legacy) callbacks expecting SK content; we'll adapt to them. +# If you've created af-native callbacks (e.g. response_handlers_af) import those instead. +from af.callbacks.response_handlers import ( + agent_response_callback, + streaming_agent_response_callback, +) +from af.config.settings import connection_config, orchestration_config +from af.models.messages import WebsocketMessageType +from af.orchestration.human_approval_manager import HumanApprovalMagenticManager class OrchestrationManager: - """Manager for handling orchestration logic.""" + """Manager for handling orchestration logic using agent_framework Magentic workflow.""" - # Class-scoped logger (always available to classmethods) logger = logging.getLogger(f"{__name__}.OrchestrationManager") def __init__(self): self.user_id: Optional[str] = None - # Optional alias (helps with autocomplete) self.logger = self.__class__.logger + # --------------------------- + # Internal callback adapters + # --------------------------- + @staticmethod + def _user_aware_agent_callback(user_id: str) -> Callable[[str, ChatMessage], Awaitable[None]]: + """Adapts agent_framework final agent ChatMessage to legacy agent_response_callback signature.""" + + async def _cb(agent_id: str, message: ChatMessage): + # Reuse existing callback expecting (ChatMessageContent, user_id). We pass text directly. + try: + agent_response_callback(message, user_id) # existing callback is sync + except Exception as e: # noqa: BLE001 + logging.getLogger(__name__).error("agent_response_callback error: %s", e) + + return _cb + + @staticmethod + def _user_aware_streaming_callback( + user_id: str, + ) -> Callable[[str, AgentRunResponseUpdate, bool], Awaitable[None]]: + """Adapts streaming updates to existing streaming handler.""" + + async def _cb(agent_id: str, update: AgentRunResponseUpdate, is_final: bool): + # Build a minimal shim object with text/content for legacy handler if needed. 
+ # Your converted streaming handlers (response_handlers_af) should replace this eventual shim. + class _Shim: # noqa: D401 + def __init__(self, agent_id: str, update: AgentRunResponseUpdate): + self.agent_id = agent_id + self.text = getattr(update, "text", None) + self.contents = getattr(update, "contents", None) + self.role = getattr(update, "role", None) + + shim = _Shim(agent_id, update) + try: + await streaming_agent_response_callback(shim, is_final, user_id) + except Exception as e: # noqa: BLE001 + logging.getLogger(__name__).error("streaming_agent_response_callback error: %s", e) + + return _cb + + # --------------------------- + # Orchestration construction + # --------------------------- @classmethod - async def init_orchestration( - cls, agents: List, user_id: str = None - ) -> MagenticOrchestration: - """Main function to run the agents.""" - - # Custom execution settings that should work with Azure OpenAI - execution_settings = OpenAIChatPromptExecutionSettings( - max_tokens=4000, temperature=0.1 - ) + async def init_orchestration(cls, agents: List, user_id: str | None = None): + """ + Initialize a Magentic workflow with: + - Provided agents (participants) + - HumanApprovalMagenticManager as orchestrator manager + - AzureOpenAIChatClient as the underlying chat client + """ + if not user_id: + raise ValueError("user_id is required to initialize orchestration") credential = config.get_azure_credential(client_id=config.AZURE_CLIENT_ID) @@ -51,63 +97,95 @@ def get_token(): token = credential.get_token("https://cognitiveservices.azure.com/.default") return token.token - # 1. Create a Magentic orchestration with Azure OpenAI - magentic_orchestration = MagenticOrchestration( - members=agents, - manager=HumanApprovalMagenticManager( - user_id=user_id, - chat_completion_service=AzureChatCompletion( - deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME, - endpoint=config.AZURE_OPENAI_ENDPOINT, - ad_token_provider=get_token, # Use token provider function - ), - execution_settings=execution_settings, - ), - agent_response_callback=cls._user_aware_agent_callback(user_id), - streaming_agent_response_callback=cls._user_aware_streaming_callback( - user_id - ), + # Create Azure chat client (agent_framework style) - relying on environment or explicit kwargs. 
+        chat_client = AzureOpenAIChatClient(
+            endpoint=config.AZURE_OPENAI_ENDPOINT,
+            model_deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME,
+            azure_ad_token_provider=get_token,
         )
-        return magentic_orchestration
 
-    @staticmethod
-    def _user_aware_agent_callback(user_id: str):
-        """Factory method that creates a callback with captured user_id"""
-
-        def callback(message: ChatMessageContent):
-            return agent_response_callback(message, user_id)
-
-        return callback
+        # HumanApprovalMagenticManager takes the chat_client in its constructor
+        # (it subclasses StandardMagenticManager).
+        manager = HumanApprovalMagenticManager(
+            user_id=user_id,
+            chat_client=chat_client,
+            instructions=None,  # optionally supply orchestrator system instructions
+            max_round_count=orchestration_config.max_rounds,
+        )
 
-    @staticmethod
-    def _user_aware_streaming_callback(user_id: str):
-        """Factory method that creates a streaming callback with captured user_id"""
-
-        async def callback(
-            streaming_message: StreamingChatMessageContent, is_final: bool
-        ):
-            return await streaming_agent_response_callback(
-                streaming_message, is_final, user_id
-            )
+        # Build participant map: use each agent's name as key
+        participants = {}
+        for ag in agents:
+            name = getattr(ag, "agent_name", None) or getattr(ag, "name", None)
+            if not name:
+                name = f"agent_{len(participants)+1}"
+            participants[name] = ag
+
+        # Assemble workflow
+        builder = (
+            MagenticBuilder()
+            .participants(**participants)
+            .with_standard_manager(manager=manager)
+        )
 
-        return callback
+        # The builder surfaces either a unified event callback or per-agent callbacks;
+        # we use per-agent callbacks here. For unified streaming events instead, use
+        # builder.on_event(..., mode=MagenticCallbackMode.STREAMING).
+        workflow = builder.build()
 
+        # Wire the adapter callbacks onto the built workflow's orchestrator. This relies
+        # on internal attributes (workflow._orchestrator._agent_response_callback, etc.),
+        # so set them defensively and only if not already configured.
+        try:
+            orchestrator = getattr(workflow, "_orchestrator", None)
+            if orchestrator:
+                if getattr(orchestrator, "_agent_response_callback", None) is None:
+                    setattr(
+                        orchestrator,
+                        "_agent_response_callback",
+                        cls._user_aware_agent_callback(user_id),
+                    )
+                if getattr(orchestrator, "_streaming_agent_response_callback", None) is None:
+                    setattr(
+                        orchestrator,
+                        "_streaming_agent_response_callback",
+                        cls._user_aware_streaming_callback(user_id),
+                    )
+        except Exception as e:  # noqa: BLE001
+            cls.logger.warning("Could not attach callbacks to workflow orchestrator: %s", e)
+
+        return workflow
+
+    # ---------------------------
+    # Orchestration retrieval
+    # ---------------------------
     @classmethod
     async def get_current_or_new_orchestration(
         cls, user_id: str, team_config: TeamConfiguration, team_switched: bool
-    ) -> MagenticOrchestration:  # add team_switched: bool parameter
-        """get existing orchestration instance."""
-        current_orchestration = orchestration_config.get_current_orchestration(user_id)
-        if (
-            current_orchestration is None or team_switched
-        ):  # add check for team_switched flag
-            if current_orchestration is not None and team_switched:
-                for agent in current_orchestration._members:
-                    if agent.name != "ProxyAgent":
-                        try:
-                            await agent.close()
-                        except Exception as e:
-                            cls.logger.error("Error closing agent: %s", e)
+    ):
+        """
+        Return an existing workflow for the user or create a new one if:
+        - None exists
+        - Team switched flag is True
+        """
+        current = orchestration_config.get_current_orchestration(user_id)
+        if current is None or team_switched:
+            if current is not None and team_switched:
+                # Close prior agents (skip ProxyAgent if desired)
+                for agent in getattr(current, "_participants", {}).values():
+                    if getattr(agent, "agent_name", getattr(agent, "name", "")) != "ProxyAgent":
+                        close_coro = getattr(agent, "close", None)
+                        if callable(close_coro):
+                            try:
+                                await close_coro()
+                            except Exception as e:  # noqa: BLE001
+                                cls.logger.error("Error closing agent: %s", e)
+
+            # Build new participants via the agent factory (moved to af in patch 7;
+            # local import avoids a circular dependency)
+            from af.magentic_agents.magentic_agent_factory import MagenticAgentFactory
+
             factory = MagenticAgentFactory()
             agents = await factory.get_agents(user_id=user_id, team_config_input=team_config)
 
             orchestration_config.orchestrations[user_id] = await cls.init_orchestration(
@@ -115,66 +193,57 @@ async def get_current_or_new_orchestration(
             )
         return orchestration_config.get_current_orchestration(user_id)
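
From a caller's perspective, the retrieval path above can be exercised roughly as follows (identifiers hypothetical; team_config would come from TeamService elsewhere in the codebase):

    # Hypothetical driver for workflow retrieval and reuse.
    workflow = await OrchestrationManager.get_current_or_new_orchestration(
        user_id="user-123",       # hypothetical
        team_config=team_config,  # TeamConfiguration loaded elsewhere
        team_switched=True,       # True closes prior non-Proxy agents and rebuilds
    )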
 
-    async def run_orchestration(self, user_id, input_task) -> None:
-        """Run the orchestration with user input loop."""
-
+    # ---------------------------
+    # Execution
+    # ---------------------------
+    async def run_orchestration(self, user_id: str, input_task) -> None:
+        """
+        Execute the Magentic workflow for the provided user and task description.
+        """
         job_id = str(uuid.uuid4())
-
-        # Use the new event-driven method to set approval as pending
         orchestration_config.set_approval_pending(job_id)
 
-        magentic_orchestration = orchestration_config.get_current_orchestration(user_id)
-
-        if magentic_orchestration is None:
+        workflow = orchestration_config.get_current_orchestration(user_id)
+        if workflow is None:
             raise ValueError("Orchestration not initialized for user.")
 
+        # Ensure manager tracks user_id
         try:
-            if hasattr(magentic_orchestration, "_manager") and hasattr(
-                magentic_orchestration._manager, "current_user_id"
-            ):
-                magentic_orchestration._manager.current_user_id = user_id
-                self.logger.debug(f"DEBUG: Set user_id on manager = {user_id}")
-        except Exception as e:
-            self.logger.error(f"Error setting user_id on manager: {e}")
-
-        runtime = InProcessRuntime()
-        runtime.start()
+            manager = getattr(workflow, "_manager", None)
+            if manager and hasattr(manager, "current_user_id"):
+                manager.current_user_id = user_id
+        except Exception as e:  # noqa: BLE001
+            self.logger.error("Error setting user_id on manager: %s", e)
+
+        # Build a MagenticContext-like starting message; the workflow interface likely exposes invoke(task=...)
+        task_text = getattr(input_task, "description", str(input_task))
+
+        # Provide chat options (temperature mapping from original execution_settings)
+        chat_options = ChatOptions(
+            temperature=0.1,
+            max_output_tokens=4000,
+        )
 
         try:
-
-            orchestration_result = await magentic_orchestration.invoke(
-                task=input_task.description,
-                runtime=runtime,
-            )
-
-            try:
-                self.logger.info("\nAgent responses:")
-                value = await orchestration_result.get()
-                self.logger.info(f"\nFinal result:\n{value}")
-                self.logger.info("=" * 50)
-
-                # Send final result via WebSocket
-                await connection_config.send_status_update_async(
-                    {
-                        "type": WebsocketMessageType.FINAL_RESULT_MESSAGE,
-                        "data": {
-                            "content": str(value),
-                            "status": "completed",
-                            "timestamp": asyncio.get_event_loop().time(),
-                        },
+            # Invoke orchestrator; API may be workflow.invoke(task=..., chat_options=...)
+            result_msg: ChatMessage = await workflow.invoke(task=task_text, chat_options=chat_options)
+
+            final_text = result_msg.text if result_msg else ""
+            self.logger.info("Final result:\n%s", final_text)
+            self.logger.info("=" * 50)
+
+            await connection_config.send_status_update_async(
+                {
+                    "type": WebsocketMessageType.FINAL_RESULT_MESSAGE,
+                    "data": {
+                        "content": final_text,
+                        "status": "completed",
+                        "timestamp": asyncio.get_event_loop().time(),
                     },
-                    user_id,
-                    message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE,
-                )
-                self.logger.info(f"Final result sent via WebSocket to user {user_id}")
-            except Exception as e:
-                self.logger.info(f"Error: {e}")
-                self.logger.info(f"Error type: {type(e).__name__}")
-                if hasattr(e, "__dict__"):
-                    self.logger.info(f"Error attributes: {e.__dict__}")
-                self.logger.info("=" * 50)
-
-        except Exception as e:
-            self.logger.error(f"Unexpected error: {e}")
-        finally:
-            await runtime.stop_when_idle()
+                },
+                user_id,
+                message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE,
+            )
+            self.logger.info("Final result sent via WebSocket to user %s", user_id)
+        except Exception as e:  # noqa: BLE001
+            self.logger.error("Unexpected orchestration error: %s", e)
\ No newline at end of file
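
The series relies on an event-driven approval contract on orchestration_config (set_approval_pending, wait_for_approval, cleanup_approval, plus the approvals mapping and default_timeout) that is not shown in these patches. One way it could be satisfied is sketched below with one asyncio.Event per plan; everything beyond those grounded names is an assumption:

    import asyncio
    from typing import Dict, Optional

    class ApprovalRegistry:
        """Sketch of the approval contract used by HumanApprovalMagenticManager."""

        def __init__(self, default_timeout: float = 300.0) -> None:
            self.default_timeout = default_timeout
            self.approvals: Dict[str, Optional[bool]] = {}
            self._events: Dict[str, asyncio.Event] = {}

        def set_approval_pending(self, plan_id: str) -> None:
            # Called before the approval request goes out over the websocket.
            self.approvals[plan_id] = None
            self._events[plan_id] = asyncio.Event()

        def resolve_approval(self, plan_id: str, approved: bool) -> None:
            # Hypothetical hook for the API layer to record the user's response.
            self.approvals[plan_id] = approved
            self._events[plan_id].set()

        async def wait_for_approval(self, plan_id: str) -> bool:
            # Raises asyncio.TimeoutError, which the manager maps to a
            # TimeoutNotification and cleanup.
            await asyncio.wait_for(self._events[plan_id].wait(), timeout=self.default_timeout)
            return bool(self.approvals[plan_id])

        def cleanup_approval(self, plan_id: str) -> None:
            self.approvals.pop(plan_id, None)
            self._events.pop(plan_id, None)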