Skip to content

Commit 76431ba

Browse files
authored
fix: Persist rendering type detection results in AdaptivePlaywrightCrawler (#2987)
- closes #2899 - The PR introduces a `RecoverableState` class with the hopes of using it for all persistent components (`Statistics`, `SessionPool`, ...), similarly to the Python port, in the future - `RecoverableState` is utilized in `RenderingTypePredictor` to make its state persistent by default for better detection reliability after pauses and migrations
1 parent 9eaa226 commit 76431ba

File tree

9 files changed

+451
-16
lines changed

9 files changed

+451
-16
lines changed

packages/core/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ export * from './session_pool';
1414
export * from './storages';
1515
export * from './validators';
1616
export * from './cookie_utils';
17+
export * from './recoverable_state';
1718
export { PseudoUrl } from '@apify/pseudo_url';
1819
export { Dictionary, Awaitable, Constructor, StorageClient, Cookie, QueueOperationInfo } from '@crawlee/types';
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import { Configuration, EventType, KeyValueStore } from '@crawlee/core';
2+
3+
import type { Log } from '@apify/log';
4+
import log from '@apify/log';
5+
6+
export interface RecoverableStatePersistenceOptions {
7+
/**
8+
* The key under which the state is stored in the KeyValueStore
9+
*/
10+
persistStateKey: string;
11+
12+
/**
13+
* Flag to enable or disable state persistence
14+
*/
15+
persistenceEnabled?: boolean;
16+
17+
/**
18+
* The name of the KeyValueStore to use for persistence.
19+
* If neither a name nor an id are supplied, the default store will be used.
20+
*/
21+
persistStateKvsName?: string;
22+
23+
/**
24+
* The identifier of the KeyValueStore to use for persistence.
25+
* If neither a name nor an id are supplied, the default store will be used.
26+
*/
27+
persistStateKvsId?: string;
28+
}
29+
30+
/**
31+
* Options for configuring the RecoverableState
32+
*/
33+
export interface RecoverableStateOptions<TStateModel = Record<string, unknown>>
34+
extends RecoverableStatePersistenceOptions {
35+
/**
36+
* The default state used if no persisted state is found.
37+
* A deep copy is made each time the state is used.
38+
*/
39+
defaultState: TStateModel;
40+
41+
/**
42+
* A logger instance for logging operations related to state persistence
43+
*/
44+
logger?: Log;
45+
46+
/**
47+
* Configuration instance to use
48+
*/
49+
config?: Configuration;
50+
51+
/**
52+
* Optional function to transform the state to a JSON string before persistence.
53+
* If not provided, JSON.stringify will be used.
54+
*/
55+
serialize?: (state: TStateModel) => string;
56+
57+
/**
58+
* Optional function to transform a JSON-serialized object back to the state model.
59+
* If not provided, JSON.parse is used.
60+
* It is advisable to perform validation in this function and to throw an exception if it fails.
61+
*/
62+
deserialize?: (serializedState: string) => TStateModel;
63+
}
64+
65+
/**
66+
* A class for managing persistent recoverable state using a plain JavaScript object.
67+
*
68+
* This class facilitates state persistence to a `KeyValueStore`, allowing data to be saved and retrieved
69+
* across migrations or restarts. It manages the loading, saving, and resetting of state data,
70+
* with optional persistence capabilities.
71+
*
72+
* The state is represented by a plain JavaScript object that can be serialized to and deserialized from JSON.
73+
* The class automatically hooks into the event system to persist state when needed.
74+
*/
75+
export class RecoverableState<TStateModel = Record<string, unknown>> {
76+
private readonly defaultState: TStateModel;
77+
private state: TStateModel | null = null;
78+
private readonly persistenceEnabled: boolean;
79+
private readonly persistStateKey: string;
80+
private readonly persistStateKvsName?: string;
81+
private readonly persistStateKvsId?: string;
82+
private keyValueStore: KeyValueStore | null = null;
83+
private readonly log: Log;
84+
private readonly config: Configuration;
85+
private readonly serialize: (state: TStateModel) => string;
86+
private readonly deserialize: (serializedState: string) => TStateModel;
87+
88+
/**
89+
* Initialize a new recoverable state object.
90+
*
91+
* @param options Configuration options for the recoverable state
92+
*/
93+
constructor(options: RecoverableStateOptions<TStateModel>) {
94+
this.defaultState = options.defaultState;
95+
this.persistStateKey = options.persistStateKey;
96+
this.persistenceEnabled = options.persistenceEnabled ?? false;
97+
this.persistStateKvsName = options.persistStateKvsName;
98+
this.persistStateKvsId = options.persistStateKvsId;
99+
this.log = options.logger ?? log.child({ prefix: 'RecoverableState' });
100+
this.config = options.config ?? Configuration.getGlobalConfig();
101+
this.serialize = options.serialize ?? JSON.stringify;
102+
this.deserialize = options.deserialize ?? JSON.parse;
103+
104+
this.persistState = this.persistState.bind(this);
105+
}
106+
107+
/**
108+
* Initialize the recoverable state.
109+
*
110+
* This method must be called before using the recoverable state. It loads the saved state
111+
* if persistence is enabled and registers the object to listen for PERSIST_STATE events.
112+
*
113+
* @returns The loaded state object
114+
*/
115+
async initialize(): Promise<TStateModel> {
116+
if (this.state !== null && this.state !== undefined) {
117+
return this.currentValue;
118+
}
119+
120+
if (!this.persistenceEnabled) {
121+
this.state = this.deserialize(this.serialize(this.defaultState));
122+
return this.currentValue;
123+
}
124+
125+
this.keyValueStore = await KeyValueStore.open(this.persistStateKvsName ?? this.persistStateKvsId, {
126+
config: this.config,
127+
});
128+
129+
await this.loadSavedState();
130+
131+
// Register for persist state events
132+
const eventManager = this.config.getEventManager();
133+
eventManager.on(EventType.PERSIST_STATE, this.persistState);
134+
135+
return this.currentValue;
136+
}
137+
138+
/**
139+
* Clean up resources used by the recoverable state.
140+
*
141+
* If persistence is enabled, this method deregisters the object from PERSIST_STATE events
142+
* and persists the current state one last time.
143+
*/
144+
async teardown(): Promise<void> {
145+
if (!this.persistenceEnabled || !this.persistState) {
146+
return;
147+
}
148+
149+
const eventManager = this.config.getEventManager();
150+
eventManager.off(EventType.PERSIST_STATE, this.persistState);
151+
await this.persistState();
152+
}
153+
154+
/**
155+
* Get the current state.
156+
*/
157+
get currentValue(): TStateModel {
158+
if (this.state === null) {
159+
throw new Error('Recoverable state has not yet been loaded');
160+
}
161+
162+
return this.state;
163+
}
164+
165+
/**
166+
* Reset the state to the default values and clear any persisted state.
167+
*
168+
* Resets the current state to the default state and, if persistence is enabled,
169+
* clears the persisted state from the KeyValueStore.
170+
*/
171+
async reset(): Promise<void> {
172+
this.state = this.deserialize(this.serialize(this.defaultState));
173+
174+
if (this.persistenceEnabled) {
175+
if (this.keyValueStore === null) {
176+
throw new Error('Recoverable state has not yet been initialized');
177+
}
178+
179+
await this.keyValueStore.setValue(this.persistStateKey, null);
180+
}
181+
}
182+
183+
/**
184+
* Persist the current state to the KeyValueStore.
185+
*
186+
* This method is typically called in response to a PERSIST_STATE event, but can also be called
187+
* directly when needed.
188+
*
189+
* @param eventData Optional data associated with a PERSIST_STATE event
190+
*/
191+
async persistState(eventData?: { isMigrating: boolean }): Promise<void> {
192+
this.log.debug(`Persisting state of the RecoverableState (eventData=${JSON.stringify(eventData)}).`);
193+
194+
if (this.keyValueStore === null || this.state === null) {
195+
throw new Error('Recoverable state has not yet been initialized');
196+
}
197+
198+
if (this.persistenceEnabled) {
199+
await this.keyValueStore.setValue(this.persistStateKey, this.serialize(this.state), {
200+
contentType: 'text/plain', // HACK - the result is expected to be JSON, but we do this to avoid the implicit JSON.parse in `KeyValueStore.getValue`
201+
});
202+
}
203+
}
204+
205+
/**
206+
* Load the saved state from the KeyValueStore
207+
*/
208+
private async loadSavedState(): Promise<void> {
209+
if (this.keyValueStore === null) {
210+
throw new Error('Recoverable state has not yet been initialized');
211+
}
212+
213+
const storedState = await this.keyValueStore.getValue(this.persistStateKey);
214+
if (storedState === null || storedState === undefined) {
215+
this.state = this.deserialize(this.serialize(this.defaultState));
216+
} else {
217+
this.state = this.deserialize(storedState as string);
218+
}
219+
}
220+
}

packages/core/tsconfig.build.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{
22
"extends": "../../tsconfig.build.json",
33
"compilerOptions": {
4-
"outDir": "./dist"
4+
"outDir": "./dist",
5+
"rootDir": "./src"
56
},
67
"include": ["src/**/*"]
78
}

packages/playwright-crawler/src/internals/adaptive-playwright-crawler.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ export interface AdaptivePlaywrightCrawlerOptions
201201
/**
202202
* A custom rendering type predictor
203203
*/
204-
renderingTypePredictor?: Pick<RenderingTypePredictor, 'predict' | 'storeResult'>;
204+
renderingTypePredictor?: Pick<RenderingTypePredictor, 'predict' | 'storeResult' | 'initialize'>;
205205

206206
/**
207207
* Prevent direct access to storage in request handlers (only allow using context helpers).
@@ -314,6 +314,11 @@ export class AdaptivePlaywrightCrawler extends PlaywrightCrawler {
314314
this.preventDirectStorageAccess = preventDirectStorageAccess;
315315
}
316316

317+
protected override async _init(): Promise<void> {
318+
await this.renderingTypePredictor.initialize();
319+
return await super._init();
320+
}
321+
317322
protected override async _runRequestHandler(crawlingContext: PlaywrightCrawlingContext): Promise<void> {
318323
const renderingTypePrediction = this.renderingTypePredictor.predict(crawlingContext.request);
319324
const shouldDetectRenderingType = Math.random() < renderingTypePrediction.detectionProbabilityRecommendation;

packages/playwright-crawler/src/internals/utils/rendering-type-prediction.ts

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import type { Request } from '@crawlee/core';
1+
import type { RecoverableStatePersistenceOptions, Request } from '@crawlee/core';
2+
import { RecoverableState } from '@crawlee/core';
23
import LogisticRegression from 'ml-logistic-regression';
34
import { Matrix } from 'ml-matrix';
45
import stringComparison from 'string-comparison';
@@ -33,6 +34,7 @@ type FeatureVector = [staticResultsSimilarity: number, clientOnlyResultsSimilari
3334
export interface RenderingTypePredictorOptions {
3435
/** A number between 0 and 1 that determines the desired ratio of rendering type detections */
3536
detectionRatio: number;
37+
persistenceOptions?: Partial<RecoverableStatePersistenceOptions>;
3638
}
3739

3840
/**
@@ -43,11 +45,25 @@ export interface RenderingTypePredictorOptions {
4345
export class RenderingTypePredictor {
4446
private renderingTypeDetectionResults = new Map<RenderingType, Map<string | undefined, URLComponents[]>>();
4547
private detectionRatio: number;
46-
private logreg: LogisticRegression;
48+
private state: RecoverableState<{ logreg: LogisticRegression }>;
4749

48-
constructor({ detectionRatio }: RenderingTypePredictorOptions) {
50+
constructor({ detectionRatio, persistenceOptions }: RenderingTypePredictorOptions) {
4951
this.detectionRatio = detectionRatio;
50-
this.logreg = new LogisticRegression({ numSteps: 1000, learningRate: 0.05 });
52+
this.state = new RecoverableState({
53+
defaultState: { logreg: new LogisticRegression({ numSteps: 1000, learningRate: 0.05 }) },
54+
serialize: (state) => JSON.stringify({ logreg: state.logreg.toJSON() }),
55+
deserialize: (serializedState) => ({ logreg: LogisticRegression.load(JSON.parse(serializedState)) }),
56+
persistStateKey: 'rendering-type-predictor-state',
57+
persistenceEnabled: true,
58+
...persistenceOptions,
59+
});
60+
}
61+
62+
/**
63+
* Initialize the predictor by restoring persisted state.
64+
*/
65+
async initialize(): Promise<void> {
66+
await this.state.initialize();
5167
}
5268

5369
/**
@@ -57,18 +73,16 @@ export class RenderingTypePredictor {
5773
renderingType: RenderingType;
5874
detectionProbabilityRecommendation: number;
5975
} {
60-
if (this.logreg.classifiers.length === 0) {
76+
const { logreg } = this.state.currentValue;
77+
if (logreg.classifiers.length === 0) {
6178
return { renderingType: 'clientOnly', detectionProbabilityRecommendation: 1 };
6279
}
6380

6481
const predictionUrl = new URL(loadedUrl ?? url);
6582

6683
const urlFeature = new Matrix([this.calculateFeatureVector(urlComponents(predictionUrl), label)]);
67-
const [prediction] = this.logreg.predict(urlFeature);
68-
const scores = [
69-
this.logreg.classifiers[0].testScores(urlFeature),
70-
this.logreg.classifiers[1].testScores(urlFeature),
71-
];
84+
const [prediction] = logreg.predict(urlFeature);
85+
const scores = [logreg.classifiers[0].testScores(urlFeature), logreg.classifiers[1].testScores(urlFeature)];
7286

7387
return {
7488
renderingType: prediction === 1 ? 'static' : 'clientOnly',
@@ -134,6 +148,6 @@ export class RenderingTypePredictor {
134148
}
135149
}
136150

137-
this.logreg.train(new Matrix(X), Matrix.columnVector(Y));
151+
this.state.currentValue.logreg.train(new Matrix(X), Matrix.columnVector(Y));
138152
}
139153
}

packages/playwright-crawler/src/logistic-regression.d.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,9 @@ declare module 'ml-logistic-regression' {
1818
train(X: Matrix, Y: Matrix): void;
1919

2020
predict(Xtest: Matrix): number[];
21+
22+
static load(model: Record<string, unknown>): LogisticRegression;
23+
24+
toJSON(): Record<string, unknown>;
2125
}
2226
}

packages/utils/tsconfig.build.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"extends": "../../tsconfig.build.json",
33
"compilerOptions": {
4-
"outDir": "./dist"
4+
"outDir": "./dist",
5+
"rootDir": "./src"
56
},
6-
"include": ["src/**/*"],
7-
"rootDir": "./src"
7+
"include": ["src/**/*"]
88
}

test/core/crawlers/adaptive_playwright_crawler.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ describe('AdaptivePlaywrightCrawler', () => {
9696
detectionProbabilityRecommendation: number;
9797
renderingType: 'clientOnly' | 'static';
9898
}) => ({
99+
initialize: async () => {},
99100
predict: vi.fn((_request: Request) => prediction),
100101
storeResult: vi.fn((_request: Request, _renderingType: string) => {}),
101102
});

0 commit comments

Comments
 (0)