Skip to content

Commit f65dab5

Browse files
committed
refactor: start implementing gathering invoice
1 parent 0ab57fa commit f65dab5

File tree

10 files changed

+681
-270
lines changed

10 files changed

+681
-270
lines changed

openmeter/billing/adapter.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type Adapter interface {
1717
InvoiceLineAdapter
1818
InvoiceSplitLineGroupAdapter
1919
InvoiceAdapter
20+
GatheringInvoiceAdapter
2021
SequenceAdapter
2122
InvoiceAppAdapter
2223
CustomerSynchronizationAdapter
@@ -74,6 +75,14 @@ type InvoiceAdapter interface {
7475
GetInvoiceOwnership(ctx context.Context, input GetInvoiceOwnershipAdapterInput) (GetOwnershipAdapterResponse, error)
7576
}
7677

78+
type GatheringInvoiceAdapter interface {
79+
CreateGatheringInvoice(ctx context.Context, input CreateGatheringInvoiceAdapterInput) (GatheringInvoice, error)
80+
UpdateGatheringInvoice(ctx context.Context, input UpdateGatheringInvoiceAdapterInput) (GatheringInvoice, error)
81+
DeleteGatheringInvoice(ctx context.Context, input DeleteGatheringInvoiceAdapterInput) error
82+
// TODO: Later once we have the union type for invoices
83+
// ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (ListGatheringInvoicesResponse, error)
84+
}
85+
7786
type InvoiceSplitLineGroupAdapter interface {
7887
CreateSplitLineGroup(ctx context.Context, input CreateSplitLineGroupAdapterInput) (SplitLineGroup, error)
7988
UpdateSplitLineGroup(ctx context.Context, input UpdateSplitLineGroupInput) (SplitLineGroup, error)
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package billingadapter
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
"github.com/alpacahq/alpacadecimal"
9+
"github.com/openmeterio/openmeter/openmeter/billing"
10+
"github.com/openmeterio/openmeter/openmeter/ent/db"
11+
"github.com/openmeterio/openmeter/pkg/convert"
12+
"github.com/openmeterio/openmeter/pkg/framework/entutils"
13+
"github.com/openmeterio/openmeter/pkg/models"
14+
"github.com/openmeterio/openmeter/pkg/slicesx"
15+
"github.com/openmeterio/openmeter/pkg/timeutil"
16+
"github.com/samber/lo"
17+
)
18+
19+
var _ billing.GatheringInvoiceAdapter = (*adapter)(nil)
20+
21+
func (a *adapter) CreateGatheringInvoice(ctx context.Context, input billing.CreateGatheringInvoiceAdapterInput) (billing.GatheringInvoice, error) {
22+
if err := input.Validate(); err != nil {
23+
return billing.GatheringInvoice{}, err
24+
}
25+
26+
return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GatheringInvoice, error) {
27+
customer := input.Customer
28+
supplier := input.MergedProfile.Supplier
29+
30+
// Clone the workflow config
31+
clonedWorkflowConfig, err := tx.createWorkflowConfig(ctx, input.Namespace, input.MergedProfile.WorkflowConfig)
32+
if err != nil {
33+
return billing.GatheringInvoice{}, fmt.Errorf("clone workflow config: %w", err)
34+
}
35+
36+
workflowConfig := mapWorkflowConfigToDB(input.MergedProfile.WorkflowConfig, clonedWorkflowConfig.ID)
37+
38+
// Force cloning of the workflow
39+
// TODO: Is this needed?
40+
workflowConfig.ID = ""
41+
42+
currentSchemaLevel, err := tx.GetInvoiceDefaultSchemaLevel(ctx)
43+
if err != nil {
44+
return billing.GatheringInvoice{}, fmt.Errorf("get invoice write schema level: %w", err)
45+
}
46+
47+
createMut := tx.db.BillingInvoice.Create().
48+
SetNamespace(input.Namespace).
49+
SetMetadata(input.Metadata).
50+
SetCurrency(input.Currency).
51+
SetStatus(billing.StandardInvoiceStatusGathering).
52+
SetSourceBillingProfileID(input.MergedProfile.ID).
53+
SetType(billing.InvoiceTypeStandard). // TODO: Migrate to GatheringInvoiceType once we have the type in the database
54+
SetNumber(input.Number).
55+
SetNillableDescription(input.Description).
56+
SetNillableCollectionAt(input.NextCollectionAt).
57+
SetSchemaLevel(currentSchemaLevel).
58+
// Customer snapshot about usage attribution fields
59+
SetCustomerID(input.Customer.ID).
60+
// TODO: Remove all below this line once we have seperate tables for gathering invoices
61+
SetBillingWorkflowConfigID(clonedWorkflowConfig.ID).
62+
SetTaxAppID(input.MergedProfile.Apps.Tax.GetID().ID).
63+
SetInvoicingAppID(input.MergedProfile.Apps.Invoicing.GetID().ID).
64+
SetPaymentAppID(input.MergedProfile.Apps.Payment.GetID().ID).
65+
// Totals
66+
SetAmount(alpacadecimal.Zero).
67+
SetChargesTotal(alpacadecimal.Zero).
68+
SetDiscountsTotal(alpacadecimal.Zero).
69+
SetTaxesTotal(alpacadecimal.Zero).
70+
SetTaxesExclusiveTotal(alpacadecimal.Zero).
71+
SetTaxesInclusiveTotal(alpacadecimal.Zero).
72+
SetTotal(alpacadecimal.Zero).
73+
// Supplier contacts
74+
SetSupplierName(supplier.Name)
75+
76+
// Customer usage attribution
77+
if usageAttr := mapCustomerUsageAttributionToDB(input.Customer); usageAttr != nil {
78+
createMut = createMut.SetCustomerUsageAttribution(usageAttr)
79+
}
80+
createMut = createMut.
81+
SetCustomerName(customer.Name)
82+
83+
newInvoice, err := createMut.Save(ctx)
84+
if err != nil {
85+
return billing.GatheringInvoice{}, err
86+
}
87+
88+
// Let's add required edges for mapping
89+
newInvoice.Edges.BillingWorkflowConfig = clonedWorkflowConfig
90+
91+
return tx.mapGatheringInvoiceFromDB(ctx, newInvoice, billing.InvoiceExpandAll)
92+
})
93+
}
94+
95+
func (a *adapter) mapGatheringInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.InvoiceExpand) (billing.GatheringInvoice, error) {
96+
if invoice.Status != billing.StandardInvoiceStatusGathering {
97+
return billing.GatheringInvoice{}, fmt.Errorf("invoice is not a gathering invoice [id=%s]", invoice.ID)
98+
}
99+
100+
period := timeutil.ClosedPeriod{}
101+
102+
if invoice.PeriodStart != nil && invoice.PeriodEnd != nil {
103+
period = timeutil.ClosedPeriod{
104+
From: invoice.PeriodStart.In(time.UTC),
105+
To: invoice.PeriodEnd.In(time.UTC),
106+
}
107+
}
108+
109+
res := billing.GatheringInvoice{
110+
GatheringInvoiceBase: billing.GatheringInvoiceBase{
111+
ManagedResource: models.ManagedResource{
112+
NamespacedModel: models.NamespacedModel{
113+
Namespace: invoice.Namespace,
114+
},
115+
ManagedModel: models.ManagedModel{
116+
CreatedAt: invoice.CreatedAt,
117+
UpdatedAt: invoice.UpdatedAt,
118+
DeletedAt: convert.TimePtrIn(invoice.DeletedAt, time.UTC),
119+
},
120+
ID: invoice.ID,
121+
Name: invoice.Number,
122+
Description: invoice.Description,
123+
},
124+
125+
Metadata: invoice.Metadata,
126+
Number: invoice.Number,
127+
CustomerID: invoice.CustomerID,
128+
Currency: invoice.Currency,
129+
ServicePeriod: period,
130+
NextCollectionAt: invoice.CollectionAt.In(time.UTC),
131+
SchemaLevel: invoice.SchemaLevel,
132+
},
133+
}
134+
135+
if expand.Lines {
136+
mappedLines, err := a.mapGatheringInvoiceLinesFromDB(invoice.SchemaLevel, invoice.Edges.BillingInvoiceLines)
137+
if err != nil {
138+
return billing.GatheringInvoice{}, err
139+
}
140+
141+
// TODO[later]: Implement this once we have proper union type for invoices
142+
// mappedLines, err = a.expandSplitLineHierarchy(ctx, invoice.Namespace, mappedLines)
143+
// if err != nil {
144+
// return billing.StandardInvoice{}, err
145+
// }
146+
147+
res.Lines = billing.NewGatheringInvoiceLines(mappedLines)
148+
}
149+
150+
return res, nil
151+
}
152+
153+
func (a *adapter) mapGatheringInvoiceLinesFromDB(schemaLevel int, dbLines []*db.BillingInvoiceLine) ([]billing.GatheringLine, error) {
154+
return slicesx.MapWithErr(dbLines, func(dbLine *db.BillingInvoiceLine) (billing.GatheringLine, error) {
155+
return a.mapGatheringInvoiceLineFromDB(schemaLevel, dbLine)
156+
})
157+
}
158+
159+
func (a *adapter) mapGatheringInvoiceLineFromDB(schemaLevel int, dbLine *db.BillingInvoiceLine) (billing.GatheringLine, error) {
160+
if dbLine.Type != billing.InvoiceLineTypeUsageBased {
161+
return billing.GatheringLine{}, fmt.Errorf("only usage based lines can be gathering invoice lines [line_id=%s]", dbLine.ID)
162+
}
163+
164+
ubpLine := dbLine.Edges.UsageBasedLine
165+
if ubpLine == nil {
166+
return billing.GatheringLine{}, fmt.Errorf("usage based line data is missing [line_id=%s]", dbLine.ID)
167+
}
168+
169+
line := billing.GatheringLine{
170+
ManagedResource: models.NewManagedResource(models.ManagedResourceInput{
171+
Namespace: dbLine.Namespace,
172+
ID: dbLine.ID,
173+
CreatedAt: dbLine.CreatedAt.In(time.UTC),
174+
UpdatedAt: dbLine.UpdatedAt.In(time.UTC),
175+
DeletedAt: convert.TimePtrIn(dbLine.DeletedAt, time.UTC),
176+
Name: dbLine.Name,
177+
Description: dbLine.Description,
178+
}),
179+
180+
Metadata: dbLine.Metadata,
181+
Annotations: dbLine.Annotations,
182+
InvoiceID: dbLine.InvoiceID,
183+
ManagedBy: dbLine.ManagedBy,
184+
185+
ServicePeriod: timeutil.ClosedPeriod{
186+
From: dbLine.PeriodStart.In(time.UTC),
187+
To: dbLine.PeriodEnd.In(time.UTC),
188+
},
189+
190+
SplitLineGroupID: dbLine.SplitLineGroupID,
191+
ChildUniqueReferenceID: dbLine.ChildUniqueReferenceID,
192+
193+
InvoiceAt: dbLine.InvoiceAt.In(time.UTC),
194+
195+
Currency: dbLine.Currency,
196+
197+
TaxConfig: lo.EmptyableToPtr(dbLine.TaxConfig),
198+
RateCardDiscounts: lo.FromPtr(dbLine.RatecardDiscounts),
199+
200+
UBPConfigID: ubpLine.ID,
201+
FeatureKey: lo.FromPtr(ubpLine.FeatureKey),
202+
Price: lo.FromPtr(ubpLine.Price),
203+
}
204+
205+
if dbLine.SubscriptionID != nil && dbLine.SubscriptionPhaseID != nil && dbLine.SubscriptionItemID != nil {
206+
line.Subscription = &billing.SubscriptionReference{
207+
SubscriptionID: *dbLine.SubscriptionID,
208+
PhaseID: *dbLine.SubscriptionPhaseID,
209+
ItemID: *dbLine.SubscriptionItemID,
210+
BillingPeriod: timeutil.ClosedPeriod{
211+
From: lo.FromPtr(dbLine.SubscriptionBillingPeriodFrom).In(time.UTC),
212+
To: lo.FromPtr(dbLine.SubscriptionBillingPeriodTo).In(time.UTC),
213+
},
214+
}
215+
}
216+
217+
return line, nil
218+
}

0 commit comments

Comments
 (0)