Skip to content

Commit d9c40e2

Browse files
author
Taniya Mathur
committed
Add metering data transfer for limited classification
1 parent de3a6fd commit d9c40e2

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

lib/idp_common_pkg/idp_common/classification/service.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,20 @@ def _apply_limited_classification_to_all_pages(
303303
)
304304
original_document.sections = [section]
305305

306+
# Transfer metering data from classified document to original document
307+
if classified_document.metering:
308+
original_document.metering = utils.merge_metering_data(
309+
original_document.metering, classified_document.metering
310+
)
311+
312+
# Transfer errors from classification
313+
if classified_document.errors:
314+
original_document.errors.extend(classified_document.errors)
315+
316+
# Transfer metadata from classification
317+
if classified_document.metadata:
318+
original_document.metadata.update(classified_document.metadata)
319+
306320
logger.info(
307321
f"Applied classification '{primary_classification}' from {len(classified_document.pages)} pages to all {len(original_document.pages)} pages"
308322
)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: MIT-0
3+
4+
"""
5+
Tests for metering data transfer in limited classification scenarios.
6+
"""
7+
8+
import pytest
9+
from unittest.mock import Mock
10+
11+
from idp_common.classification.service import ClassificationService
12+
from idp_common.models import Document, Page, Section, Status
13+
14+
15+
@pytest.mark.unit
def test_apply_limited_classification_transfers_metering():
    """Metering, errors and metadata on the classified document reach the original."""
    # Service built from a minimal config carrying a classification page limit of 2.
    service = ClassificationService(
        region="us-east-1",
        config={
            "model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
            "classification": {"maxPagesForClassification": "2"},
        },
    )

    metering_key = "Classification/bedrock/anthropic.claude-3-sonnet"

    # Five-page document that starts with empty metering, errors and metadata.
    full_doc = Document(
        id="original-doc",
        pages={pid: Page(page_id=pid) for pid in ("1", "2", "3", "4", "5")},
        metering={},
        errors=[],
        metadata={},
    )

    # Two-page document that already carries classification output and usage data.
    subset_doc = Document(
        id="classified-doc",
        pages={
            pid: Page(page_id=pid, classification="invoice")
            for pid in ("1", "2")
        },
        sections=[
            Section(
                section_id="1",
                classification="invoice",
                page_ids=["1", "2"],
                confidence=0.9,
            )
        ],
        metering={
            metering_key: {
                "inputTokens": 1000,
                "outputTokens": 100,
                "totalTokens": 1100,
            }
        },
        errors=["Classification warning"],
        metadata={"processing_time": 2.5},
    )

    merged = service._apply_limited_classification_to_all_pages(
        full_doc, subset_doc
    )

    # The expected value is a separate literal so the comparison cannot pass
    # merely because both sides alias the same dict object.
    expected_metering = {
        metering_key: {
            "inputTokens": 1000,
            "outputTokens": 100,
            "totalTokens": 1100,
        }
    }
    assert merged.metering == expected_metering

    # Errors and metadata travel with the metering data.
    assert "Classification warning" in merged.errors
    assert merged.metadata["processing_time"] == 2.5

    # A single section covering every page of the original document.
    assert len(merged.sections) == 1
    only_section = merged.sections[0]
    assert only_section.classification == "invoice"
    assert len(only_section.page_ids) == 5

    # Each page carries the classification and a confidence of 1.0.
    for doc_page in merged.pages.values():
        assert doc_page.classification == "invoice"
        assert doc_page.confidence == 1.0
95+
96+
97+
@pytest.mark.unit
def test_apply_limited_classification_merges_existing_metering():
    """Existing metering on the original document survives alongside the new entry."""
    ocr_key = "OCR/textract/detect_document_text"
    llm_key = "Classification/bedrock/anthropic.claude-3-sonnet"

    service = ClassificationService(
        region="us-east-1",
        config={"model_id": "anthropic.claude-3-sonnet-20240229-v1:0"},
    )

    # The original document already has an OCR metering entry.
    full_doc = Document(
        id="original-doc",
        pages={"1": Page(page_id="1")},
        metering={ocr_key: {"pages": 1, "cost": 0.001}},
    )

    # The classified document contributes a classification metering entry.
    receipt_section = Section(
        section_id="1",
        classification="receipt",
        page_ids=["1"],
        confidence=0.8,
    )
    subset_doc = Document(
        id="classified-doc",
        pages={"1": Page(page_id="1", classification="receipt")},
        sections=[receipt_section],
        metering={llm_key: {"inputTokens": 500, "outputTokens": 50}},
    )

    merged = service._apply_limited_classification_to_all_pages(
        full_doc, subset_doc
    )

    # Both entries must be present, each with its original values.
    combined = merged.metering
    assert ocr_key in combined
    assert llm_key in combined
    assert combined[ocr_key]["pages"] == 1
    assert combined[llm_key]["inputTokens"] == 500
145+
146+
147+
@pytest.mark.unit
def test_apply_limited_classification_no_metering_data():
    """An empty metering dict on the classified document is handled without error."""
    service = ClassificationService(
        region="us-east-1",
        config={"model_id": "anthropic.claude-3-sonnet-20240229-v1:0"},
    )

    full_doc = Document(id="original-doc", pages={"1": Page(page_id="1")})

    form_section = Section(
        section_id="1",
        classification="form",
        page_ids=["1"],
        confidence=0.7,
    )
    subset_doc = Document(
        id="classified-doc",
        pages={"1": Page(page_id="1", classification="form")},
        sections=[form_section],
        metering={},  # deliberately empty
    )

    # The call itself must not raise.
    merged = service._apply_limited_classification_to_all_pages(
        full_doc, subset_doc
    )

    # Metering stays empty while the classification is still applied.
    assert merged.metering == {}
    assert merged.pages["1"].classification == "form"

0 commit comments

Comments
 (0)