diff --git a/lib/server/ai/ai_config_tracker.rb b/lib/server/ai/ai_config_tracker.rb index 2cb9337..ec79c40 100644 --- a/lib/server/ai/ai_config_tracker.rb +++ b/lib/server/ai/ai_config_tracker.rb @@ -42,13 +42,26 @@ def initialize # The AIConfigTracker class is used to track AI configuration usage. # class AIConfigTracker - attr_reader :ld_client, :config_key, :context, :variation_key, :version, :summary + attr_reader :ld_client, :config_key, :context, :variation_key, :version, :summary, :model_name, :provider_name - def initialize(ld_client:, variation_key:, config_key:, version:, context:) + # + # Initialize a new AIConfigTracker instance. + # + # @param ld_client [LDClient] The LaunchDarkly client instance + # @param variation_key [String] The variation key from the flag evaluation + # @param config_key [String] The configuration key + # @param version [Integer] The version number + # @param model_name [String] The name of the AI model being used + # @param provider_name [String] The name of the AI provider + # @param context [LDContext] The context used for the flag evaluation + # + def initialize(ld_client:, variation_key:, config_key:, version:, model_name:, provider_name:, context:) @ld_client = ld_client @variation_key = variation_key @config_key = config_key @version = version + @model_name = model_name + @provider_name = provider_name @context = context @summary = MetricSummary.new end @@ -221,7 +234,13 @@ def track_bedrock_converse_metrics(&block) end private def flag_data - { variationKey: @variation_key, configKey: @config_key, version: @version } + { + variationKey: @variation_key, + configKey: @config_key, + version: @version, + modelName: @model_name, + providerName: @provider_name, + } end private def openai_to_token_usage(usage) diff --git a/lib/server/ai/client.rb b/lib/server/ai/client.rb index 073f092..233c21c 100644 --- a/lib/server/ai/client.rb +++ b/lib/server/ai/client.rb @@ -188,6 +188,8 @@ def config(config_key, context, default_value = nil, variables = nil) variation_key: variation.dig(:_ldMeta, :variationKey) || '', config_key: config_key, version: variation.dig(:_ldMeta, :version) || 1, + model_name: model&.name || '', + provider_name: provider_config&.name || '', context: context ) diff --git a/spec/server/ai/client_spec.rb b/spec/server/ai/client_spec.rb index d5ae0dd..0d6515d 100644 --- a/spec/server/ai/client_spec.rb +++ b/spec/server/ai/client_spec.rb @@ -213,6 +213,11 @@ expect(config.provider).not_to be_nil expect(config.provider.name).to eq('fakeProvider') + expect(config.tracker).not_to be_nil + expect(config.tracker.send(:flag_data)).to include( + modelName: 'fakeModel', + providerName: 'fakeProvider' + ) end it 'interpolates context variables in messages using ldctx' do @@ -334,6 +339,11 @@ expect(config.model).to be_nil expect(config.messages).to be_nil expect(config.provider).to be_nil + expect(config.tracker).not_to be_nil + expect(config.tracker.send(:flag_data)).to include( + modelName: '', + providerName: '' + ) end end end diff --git a/spec/server/ai/config_tracker_spec.rb b/spec/server/ai/config_tracker_spec.rb index 8fed843..f2c50d0 100644 --- a/spec/server/ai/config_tracker_spec.rb +++ b/spec/server/ai/config_tracker_spec.rb @@ -27,14 +27,16 @@ end let(:context) { LaunchDarkly::LDContext.create({ key: 'user-key', kind: 'user' }) } - let(:tracker_flag_data) { { variationKey: 'test-variation', configKey: 'test-config', version: 1 } } + let(:tracker_flag_data) { { variationKey: 'test-variation', configKey: 'test-config', version: 1, modelName: 'fakeModel', providerName: 'fakeProvider' } } let(:tracker) do described_class.new( ld_client: ld_client, config_key: tracker_flag_data[:configKey], context: context, variation_key: tracker_flag_data[:variationKey], - version: tracker_flag_data[:version] + version: tracker_flag_data[:version], + model_name: 'fakeModel', + provider_name: 'fakeProvider' ) end @@ -408,4 +410,13 @@ expect(tracker.summary.time_to_first_token).to be_nil end end + + describe '#flag_data' do + it 'includes model_name and provider_name in flag data' do + expect(tracker.send(:flag_data)).to include( + modelName: 'fakeModel', + providerName: 'fakeProvider' + ) + end + end end