Skip to content

Commit 2f6b56f

Browse files
H-1767: Write service to insert embeddings into Graph (#3846)
Co-authored-by: Ciaran Morinan <[email protected]> Co-authored-by: Ciaran Morinan <[email protected]>
1 parent a43b842 commit 2f6b56f

File tree

6 files changed

+235
-6
lines changed

6 files changed

+235
-6
lines changed

apps/hash-ai-worker-ts/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"@local/advanced-types": "0.0.0-private",
3939
"@local/hash-backend-utils": "0.0.0-private",
4040
"@local/hash-isomorphic-utils": "0.0.0-private",
41+
"@local/hash-subgraph": "0.0.0-private",
4142
"@local/status": "0.0.0-private",
4243
"@temporalio/activity": "1.8.1",
4344
"@temporalio/worker": "1.8.1",
@@ -53,7 +54,6 @@
5354
"devDependencies": {
5455
"@local/eslint-config": "0.0.0-private",
5556
"@local/hash-graph-client": "0.0.0-private",
56-
"@local/hash-subgraph": "0.0.0-private",
5757
"@local/tsconfig": "0.0.0-private",
5858
"@types/dedent": "0.7.0",
5959
"@types/dotenv-flow": "3.2.0",

apps/hash-ai-worker-ts/src/activities.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@ import type {
33
InferEntitiesCallerParams,
44
InferEntitiesReturn,
55
} from "@local/hash-isomorphic-utils/ai-inference-types";
6+
import type {
7+
BaseUrl,
8+
EntityPropertiesObject,
9+
PropertyTypeWithMetadata,
10+
} from "@local/hash-subgraph";
611
import { ApplicationFailure } from "@temporalio/activity";
12+
import { CreateEmbeddingResponse } from "openai/resources";
713

14+
import { createEmbeddings } from "./activities/embeddings";
815
import { inferEntities } from "./activities/infer-entities";
916

17+
export { createGraphActivities } from "./activities/graph";
18+
1019
export const createAiActivities = ({
1120
graphApiClient,
1221
}: {
@@ -22,4 +31,14 @@ export const createAiActivities = ({
2231

2332
return status;
2433
},
34+
35+
async createEmbeddingsActivity(params: {
36+
entityProperties: EntityPropertiesObject;
37+
propertyTypes: PropertyTypeWithMetadata[];
38+
}): Promise<{
39+
embeddings: { property?: BaseUrl; embedding: number[] }[];
40+
usage: CreateEmbeddingResponse.Usage;
41+
}> {
42+
return createEmbeddings(params);
43+
},
2544
});
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import type {
2+
BaseUrl,
3+
EntityPropertiesObject,
4+
EntityPropertyValue,
5+
PropertyTypeWithMetadata,
6+
} from "@local/hash-subgraph";
7+
import OpenAI from "openai/index";
8+
import { CreateEmbeddingResponse } from "openai/resources";
9+
import Usage = CreateEmbeddingResponse.Usage;
10+
11+
// Module-level OpenAI client used by `createEmbeddings`; the API key is read
// from the `OPENAI_API_KEY` environment variable when this module is loaded.
const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
14+
15+
/**
 * Builds the text that is embedded for a single property.
 *
 * The result has the form `<property type title>: <JSON-serialised value>`,
 * so strings are quoted and `null` / `false` are rendered literally.
 */
const createEmbeddingInput = (params: {
  propertyType: PropertyTypeWithMetadata;
  property: EntityPropertyValue;
}): string => {
  const { title } = params.propertyType.schema;
  const serializedValue = JSON.stringify(params.property);

  return `${title}: ${serializedValue}`;
};
23+
24+
/**
 * Creates embeddings for an entity's properties.
 *
 * Embeddings are requested for:
 *   1. each individual '[Property Title]: [Value]' pair, and
 *   2. a single text containing all of those pairs, one per line.
 *
 * Property types are processed in order of their base URL. Property types for
 * which the entity has no value (`undefined`) are skipped; `null` and `false`
 * are kept because they are part of the property's semantic meaning.
 *
 * @param params.entityProperties - the entity's properties, keyed by base URL
 * @param params.propertyTypes - the property types to consider (not mutated)
 * @returns one embedding per property that has a value (with `property` set to
 *   that property's base URL), followed by the combined embedding as the last
 *   item (with `property` left `undefined`), plus the token usage reported by
 *   the embeddings request. If there is nothing to embed, no request is made
 *   and an empty list with zero usage is returned.
 */
export const createEmbeddings = async (params: {
  entityProperties: EntityPropertiesObject;
  propertyTypes: PropertyTypeWithMetadata[];
}): Promise<{
  embeddings: { property?: BaseUrl; embedding: number[] }[];
  usage: Usage;
}> => {
  if (params.propertyTypes.length === 0) {
    return {
      embeddings: [],
      usage: {
        prompt_tokens: 0,
        total_tokens: 0,
      },
    };
  }

  // Sort a copy: `Array.prototype.sort` works in place and the caller's array
  // must not be reordered as a side effect.
  const sortedPropertyTypes = [...params.propertyTypes].sort((a, b) =>
    a.metadata.recordId.baseUrl.localeCompare(b.metadata.recordId.baseUrl),
  );

  // `embeddedProperties[i]` is the base URL of the property that produced
  // `propertyInputs[i]`. Tracking this explicitly keeps the response aligned
  // with the right property even when some property types are skipped.
  const embeddedProperties: BaseUrl[] = [];
  const propertyInputs: string[] = [];
  for (const propertyType of sortedPropertyTypes) {
    const { baseUrl } = propertyType.metadata.recordId;
    const property = params.entityProperties[baseUrl];
    if (property === undefined) {
      // `property` could be `null` or `false` as well but that is part of the
      // semantic meaning of the property and should be included in the embedding.
      continue;
    }
    embeddedProperties.push(baseUrl);
    propertyInputs.push(createEmbeddingInput({ propertyType, property }));
  }

  if (propertyInputs.length === 0) {
    // None of the property types has a value on this entity. Without this
    // guard the request would consist of a single empty string.
    // NOTE(review): the embeddings endpoint is expected to reject empty input
    // strings — confirm against the API reference.
    return {
      embeddings: [],
      usage: {
        prompt_tokens: 0,
        total_tokens: 0,
      },
    };
  }

  // The combined 'all properties' text is sent as the last input item.
  const combinedEntityInput = propertyInputs
    .map((input) => `${input}\n`)
    .join("");

  const response = await openai.embeddings.create({
    input: [...propertyInputs, combinedEntityInput],
    model: "text-embedding-ada-002",
  });

  return {
    usage: response.usage,
    embeddings: response.data.map((data, idx) => ({
      // The last item (the combined text) has no entry in `embeddedProperties`
      // and therefore gets `property: undefined`.
      property: embeddedProperties[idx],
      embedding: data.embedding,
    })),
  };
};
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import type {
2+
EntityEmbedding,
3+
EntityTypeStructuralQuery,
4+
GraphApi,
5+
} from "@local/hash-graph-client";
6+
import type {
7+
AccountId,
8+
EntityTypeRootType,
9+
PropertyTypeWithMetadata,
10+
Subgraph,
11+
} from "@local/hash-subgraph";
12+
import {
13+
getPropertyTypes,
14+
mapGraphApiSubgraphToSubgraph,
15+
} from "@local/hash-subgraph/stdlib";
16+
17+
/**
 * Builds the Graph-related activities, each of which talks to the Graph
 * through the given `graphApiClient`.
 */
export const createGraphActivities = ({
  graphApiClient,
}: {
  graphApiClient: GraphApi;
}) => ({
  /**
   * Runs an entity type structural query on behalf of `actorId` and converts
   * the response body into a `Subgraph`.
   */
  async getEntityTypesByQuery(params: {
    authentication: {
      actorId: AccountId;
    };
    query: EntityTypeStructuralQuery;
  }): Promise<Subgraph<EntityTypeRootType>> {
    const { authentication, query } = params;
    const { data } = await graphApiClient.getEntityTypesByQuery(
      authentication.actorId,
      query,
    );

    return mapGraphApiSubgraphToSubgraph(data);
  },

  /**
   * Sends the given embeddings to the Graph on behalf of `actorId`, with the
   * request's `reset` flag set to `true`.
   */
  async updateEntityEmbeddings(params: {
    authentication: {
      actorId: AccountId;
    };
    embeddings: EntityEmbedding[];
  }): Promise<void> {
    const { authentication, embeddings } = params;

    await graphApiClient.updateEntityEmbeddings(authentication.actorId, {
      embeddings,
      reset: true,
    });
  },

  /**
   * Returns the property types contained in the given subgraph.
   */
  // eslint-disable-next-line @typescript-eslint/require-await
  async getSubgraphPropertyTypes(params: {
    subgraph: Subgraph;
  }): Promise<PropertyTypeWithMetadata[]> {
    const { subgraph } = params;

    return getPropertyTypes(subgraph);
  },
});

apps/hash-ai-worker-ts/src/main.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { Logger } from "@local/hash-backend-utils/logger";
77
import { NativeConnection, Worker } from "@temporalio/worker";
88
import { config } from "dotenv-flow";
99

10-
import { createAiActivities } from "./activities";
10+
import { createAiActivities, createGraphActivities } from "./activities";
1111

1212
export const monorepoRootDir = path.resolve(__dirname, "../../..");
1313

@@ -59,9 +59,14 @@ async function run() {
5959

6060
const worker = await Worker.create({
6161
...workflowOption(),
62-
activities: createAiActivities({
63-
graphApiClient,
64-
}),
62+
activities: {
63+
...createAiActivities({
64+
graphApiClient,
65+
}),
66+
...createGraphActivities({
67+
graphApiClient,
68+
}),
69+
},
6570
connection: await NativeConnection.connect({
6671
address: `${TEMPORAL_HOST}:${TEMPORAL_PORT}`,
6772
}),
Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import type { InferEntitiesCallerParams } from "@local/hash-isomorphic-utils/ai-inference-types";
2+
import type { AccountId, Entity } from "@local/hash-subgraph";
23
import { proxyActivities } from "@temporalio/workflow";
4+
import { CreateEmbeddingResponse } from "openai/resources";
35

4-
import { createAiActivities } from "./activities";
6+
import { createAiActivities, createGraphActivities } from "./activities";
57

68
const aiActivities = proxyActivities<ReturnType<typeof createAiActivities>>({
79
startToCloseTimeout: "1800 second",
@@ -10,5 +12,77 @@ const aiActivities = proxyActivities<ReturnType<typeof createAiActivities>>({
1012
},
1113
});
1214

15+
// Workflow-side proxy for the Graph activities: each call is allowed
// 10 seconds from start to close and at most 3 attempts.
const graphActivities = proxyActivities<
  ReturnType<typeof createGraphActivities>
>({
  retry: { maximumAttempts: 3 },
  startToCloseTimeout: "10 second",
});
23+
1324
/**
 * Workflow that forwards its parameters unchanged to `inferEntitiesActivity`
 * and returns whatever that activity returns.
 */
export const inferEntities = (params: InferEntitiesCallerParams) => {
  return aiActivities.inferEntitiesActivity(params);
};
26+
27+
/**
 * Workflow that (re)generates the embeddings for a single entity.
 *
 * Steps:
 *   1. query the entity's type, resolving `inheritsFrom` up to depth 255 and
 *      `constrainsPropertiesOn` to depth 1,
 *   2. extract the property types from the resulting subgraph,
 *   3. create embeddings for the entity's properties,
 *   4. if any embeddings were created, store them against the entity's id.
 *
 * @returns the token usage reported by the embedding activity
 */
export const updateEntityEmbeddings = async (params: {
  authentication: {
    actorId: AccountId;
  };
  entity: Entity;
}): Promise<CreateEmbeddingResponse.Usage> => {
  const { authentication, entity } = params;

  const entityTypeSubgraph = await graphActivities.getEntityTypesByQuery({
    authentication,
    query: {
      filter: {
        equal: [
          { path: ["versionedUrl"] },
          { parameter: entity.metadata.entityTypeId },
        ],
      },
      graphResolveDepths: {
        inheritsFrom: { outgoing: 255 },
        constrainsValuesOn: { outgoing: 0 },
        constrainsPropertiesOn: { outgoing: 1 },
        constrainsLinksOn: { outgoing: 0 },
        constrainsLinkDestinationsOn: { outgoing: 0 },
        isOfType: { outgoing: 0 },
        hasLeftEntity: { incoming: 0, outgoing: 0 },
        hasRightEntity: { incoming: 0, outgoing: 0 },
      },
      temporalAxes: {
        pinned: {
          axis: "transactionTime",
          timestamp: null,
        },
        variable: {
          axis: "decisionTime",
          interval: {
            start: null,
            end: null,
          },
        },
      },
      includeDrafts: false,
    },
  });

  const propertyTypes = await graphActivities.getSubgraphPropertyTypes({
    subgraph: entityTypeSubgraph,
  });

  const { embeddings, usage } = await aiActivities.createEmbeddingsActivity({
    entityProperties: entity.properties,
    propertyTypes,
  });

  // Nothing was embedded, so there is nothing to write back.
  if (embeddings.length === 0) {
    return usage;
  }

  const { entityId } = entity.metadata.recordId;
  await graphActivities.updateEntityEmbeddings({
    authentication,
    embeddings: embeddings.map((embedding) => ({ ...embedding, entityId })),
  });

  return usage;
};

0 commit comments

Comments
 (0)