Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ export * from './session_pool';
export * from './storages';
export * from './validators';
export * from './cookie_utils';
export * from './recoverable_state';
export { PseudoUrl } from '@apify/pseudo_url';
export { Dictionary, Awaitable, Constructor, StorageClient, Cookie, QueueOperationInfo } from '@crawlee/types';
220 changes: 220 additions & 0 deletions packages/core/src/recoverable_state.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import { Configuration, EventType, KeyValueStore } from '@crawlee/core';

import type { Log } from '@apify/log';
import log from '@apify/log';

/**
* Options that control whether and where a `RecoverableState` persists its data.
*/
export interface RecoverableStatePersistenceOptions {
/**
* The key under which the state is stored in the KeyValueStore.
* The same key is read when the state is loaded and cleared when the state is reset.
*/
persistStateKey: string;

/**
* Flag to enable or disable state persistence.
* `RecoverableState` treats an omitted value as `false`.
*/
persistenceEnabled?: boolean;

/**
* The name of the KeyValueStore to use for persistence.
* If neither a name nor an id are supplied, the default store will be used.
* If both are supplied, `RecoverableState` uses the name.
*/
persistStateKvsName?: string;

/**
* The identifier of the KeyValueStore to use for persistence.
* If neither a name nor an id are supplied, the default store will be used.
* Only consulted by `RecoverableState` when no name is supplied.
*/
persistStateKvsId?: string;
}

/**
* Options for configuring the RecoverableState.
*
* Extends the persistence options with the default state, logging, configuration
* and (de)serialization hooks.
*/
export interface RecoverableStateOptions<TStateModel = Record<string, unknown>>
extends RecoverableStatePersistenceOptions {
/**
* The default state used if no persisted state is found.
* A deep copy is made each time the state is used; `RecoverableState` produces the copy
* by passing this value through `serialize` and then `deserialize`.
*/
defaultState: TStateModel;

/**
* A logger instance for logging operations related to state persistence.
* If omitted, `RecoverableState` creates a child of the default `@apify/log` logger
* with the `RecoverableState` prefix.
*/
logger?: Log;

/**
* Configuration instance to use.
* If omitted, `RecoverableState` falls back to `Configuration.getGlobalConfig()`.
*/
config?: Configuration;

/**
* Optional function to transform the state to a JSON string before persistence.
* If not provided, JSON.stringify will be used.
*/
serialize?: (state: TStateModel) => string;

/**
* Optional function to transform a JSON-serialized object back to the state model.
* If not provided, JSON.parse is used.
* It is advisable to perform validation in this function and to throw an exception if it fails.
*/
deserialize?: (serializedState: string) => TStateModel;
}

/**
* A class for managing persistent recoverable state using a plain JavaScript object.
*
* This class facilitates state persistence to a `KeyValueStore`, allowing data to be saved and retrieved
* across migrations or restarts. It manages the loading, saving, and resetting of state data,
* with optional persistence capabilities.
*
* The state is represented by a plain JavaScript object that can be serialized to and deserialized from JSON.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this is no longer completely true with custom serialize and deserialize methods, but there is IMO nothing wrong with not telling the users the whole truth, keeping them on the safe side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, "plain JavaScript object" is not exactly a legal term and we do accept stuff that can be (de)serialized from/to JSON, if only with the caveat that sometimes you have to write the serialization logic by hand.

image

* The class automatically hooks into the event system to persist state when needed.
*/
export class RecoverableState<TStateModel = Record<string, unknown>> {
    /** Template for the state; it is never handed out directly, only copied. */
    private readonly defaultState: TStateModel;
    /** The live state; `null` until `initialize()` or `reset()` has populated it. */
    private state: TStateModel | null = null;
    private readonly persistenceEnabled: boolean;
    private readonly persistStateKey: string;
    private readonly persistStateKvsName?: string;
    private readonly persistStateKvsId?: string;
    /** Backing store; only opened by `initialize()` when persistence is enabled. */
    private keyValueStore: KeyValueStore | null = null;
    private readonly log: Log;
    private readonly config: Configuration;
    private readonly serialize: (state: TStateModel) => string;
    private readonly deserialize: (serializedState: string) => TStateModel;

    /**
     * Initialize a new recoverable state object.
     *
     * Nothing is loaded or registered here; call `initialize()` before reading the state.
     *
     * @param options Configuration options for the recoverable state
     */
    constructor(options: RecoverableStateOptions<TStateModel>) {
        this.defaultState = options.defaultState;
        this.persistStateKey = options.persistStateKey;
        this.persistenceEnabled = options.persistenceEnabled ?? false;
        this.persistStateKvsName = options.persistStateKvsName;
        this.persistStateKvsId = options.persistStateKvsId;
        this.log = options.logger ?? log.child({ prefix: 'RecoverableState' });
        this.config = options.config ?? Configuration.getGlobalConfig();
        this.serialize = options.serialize ?? JSON.stringify;
        this.deserialize = options.deserialize ?? JSON.parse;

        // The method is passed around as an event listener, so it needs a stable bound reference
        // that can later be handed to `off()` as well.
        this.persistState = this.persistState.bind(this);
    }

    /**
     * Initialize the recoverable state.
     *
     * This method must be called before using the recoverable state. It loads the saved state
     * if persistence is enabled and registers the object to listen for PERSIST_STATE events.
     * Calling it again once the state is loaded returns the current state without reloading.
     *
     * @returns The loaded state object
     */
    async initialize(): Promise<TStateModel> {
        if (this.state !== null && this.state !== undefined) {
            return this.currentValue;
        }

        if (!this.persistenceEnabled) {
            this.state = this.cloneDefaultState();
            return this.currentValue;
        }

        // The store name takes precedence over the id; with neither set, `undefined` is passed.
        this.keyValueStore = await KeyValueStore.open(this.persistStateKvsName ?? this.persistStateKvsId, {
            config: this.config,
        });

        await this.loadSavedState();

        // Register for persist state events
        this.config.getEventManager().on(EventType.PERSIST_STATE, this.persistState);

        return this.currentValue;
    }

    /**
     * Clean up resources used by the recoverable state.
     *
     * If persistence is enabled, this method deregisters the object from PERSIST_STATE events
     * and persists the current state one last time. With persistence disabled it does nothing.
     */
    async teardown(): Promise<void> {
        if (!this.persistenceEnabled) {
            return;
        }

        this.config.getEventManager().off(EventType.PERSIST_STATE, this.persistState);
        await this.persistState();
    }

    /**
     * Get the current state.
     *
     * @throws Error if the state has not been loaded yet
     */
    get currentValue(): TStateModel {
        if (this.state === null) {
            throw new Error('Recoverable state has not yet been loaded');
        }

        return this.state;
    }

    /**
     * Reset the state to the default values and clear any persisted state.
     *
     * Resets the current state to the default state and, if persistence is enabled,
     * clears the persisted state from the KeyValueStore.
     *
     * @throws Error if persistence is enabled but the KeyValueStore has not been opened yet
     */
    async reset(): Promise<void> {
        this.state = this.cloneDefaultState();

        if (!this.persistenceEnabled) {
            return;
        }

        if (this.keyValueStore === null) {
            throw new Error('Recoverable state has not yet been initialized');
        }

        await this.keyValueStore.setValue(this.persistStateKey, null);
    }

    /**
     * Persist the current state to the KeyValueStore.
     *
     * This method is typically called in response to a PERSIST_STATE event, but can also be called
     * directly when needed. When persistence is disabled and the state has been initialized,
     * the call is a no-op.
     *
     * @param eventData Optional data associated with a PERSIST_STATE event
     * @throws Error if the state has not been initialized yet
     */
    async persistState(eventData?: { isMigrating: boolean }): Promise<void> {
        this.log.debug(`Persisting state of the RecoverableState (eventData=${JSON.stringify(eventData)}).`);

        if (this.state === null) {
            throw new Error('Recoverable state has not yet been initialized');
        }

        // Without persistence no store is ever opened, so there is nothing to write to.
        // This has to be checked before the store check below, otherwise a disabled
        // instance would always be reported as uninitialized.
        if (!this.persistenceEnabled) {
            return;
        }

        if (this.keyValueStore === null) {
            throw new Error('Recoverable state has not yet been initialized');
        }

        await this.keyValueStore.setValue(this.persistStateKey, this.serialize(this.state), {
            contentType: 'text/plain', // HACK - the result is expected to be JSON, but we do this to avoid the implicit JSON.parse in `KeyValueStore.getValue`
        });
    }

    /**
     * Load the saved state from the KeyValueStore, falling back to a copy of the default state
     * when nothing is stored under the configured key.
     */
    private async loadSavedState(): Promise<void> {
        if (this.keyValueStore === null) {
            throw new Error('Recoverable state has not yet been initialized');
        }

        const storedState = await this.keyValueStore.getValue(this.persistStateKey);
        if (storedState === null || storedState === undefined) {
            this.state = this.cloneDefaultState();
        } else {
            // The state is written as a string by `persistState`, so it is deserialized here
            // by the configured function.
            this.state = this.deserialize(storedState as string);
        }
    }

    /**
     * Produce an independent copy of the default state by round-tripping it through
     * the configured serialize/deserialize functions.
     */
    private cloneDefaultState(): TStateModel {
        return this.deserialize(this.serialize(this.defaultState));
    }
}
3 changes: 2 additions & 1 deletion packages/core/tsconfig.build.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"extends": "../../tsconfig.build.json",
"compilerOptions": {
"outDir": "./dist"
"outDir": "./dist",
"rootDir": "./src"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need this (and the changes in the other tsconfig)? It feels wrong to change it just for two packages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

},
"include": ["src/**/*"]
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ export interface AdaptivePlaywrightCrawlerOptions
/**
* A custom rendering type predictor
*/
renderingTypePredictor?: Pick<RenderingTypePredictor, 'predict' | 'storeResult'>;
renderingTypePredictor?: Pick<RenderingTypePredictor, 'predict' | 'storeResult' | 'initialize'>;

/**
* Prevent direct access to storage in request handlers (only allow using context helpers).
Expand Down Expand Up @@ -314,6 +314,11 @@ export class AdaptivePlaywrightCrawler extends PlaywrightCrawler {
this.preventDirectStorageAccess = preventDirectStorageAccess;
}

/**
 * Initializes the rendering type predictor and then runs the parent class initialization.
 */
protected override async _init(): Promise<void> {
// Awaited before the parent `_init()` so the predictor is initialized first.
await this.renderingTypePredictor.initialize();
return await super._init();
}

protected override async _runRequestHandler(crawlingContext: PlaywrightCrawlingContext): Promise<void> {
const renderingTypePrediction = this.renderingTypePredictor.predict(crawlingContext.request);
const shouldDetectRenderingType = Math.random() < renderingTypePrediction.detectionProbabilityRecommendation;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Request } from '@crawlee/core';
import type { RecoverableStatePersistenceOptions, Request } from '@crawlee/core';
import { RecoverableState } from '@crawlee/core';
import LogisticRegression from 'ml-logistic-regression';
import { Matrix } from 'ml-matrix';
import stringComparison from 'string-comparison';
Expand Down Expand Up @@ -33,6 +34,7 @@ type FeatureVector = [staticResultsSimilarity: number, clientOnlyResultsSimilari
export interface RenderingTypePredictorOptions {
/** A number between 0 and 1 that determines the desired ratio of rendering type detections */
detectionRatio: number;
/**
 * Optional overrides for how the predictor's state is persisted.
 * The predictor spreads these over its own defaults
 * (`persistStateKey: 'rendering-type-predictor-state'`, `persistenceEnabled: true`).
 */
persistenceOptions?: Partial<RecoverableStatePersistenceOptions>;
}

/**
Expand All @@ -43,11 +45,25 @@ export interface RenderingTypePredictorOptions {
export class RenderingTypePredictor {
private renderingTypeDetectionResults = new Map<RenderingType, Map<string | undefined, URLComponents[]>>();
private detectionRatio: number;
private logreg: LogisticRegression;
private state: RecoverableState<{ logreg: LogisticRegression }>;

constructor({ detectionRatio }: RenderingTypePredictorOptions) {
constructor({ detectionRatio, persistenceOptions }: RenderingTypePredictorOptions) {
this.detectionRatio = detectionRatio;
this.logreg = new LogisticRegression({ numSteps: 1000, learningRate: 0.05 });
this.state = new RecoverableState({
defaultState: { logreg: new LogisticRegression({ numSteps: 1000, learningRate: 0.05 }) },
serialize: (state) => JSON.stringify({ logreg: state.logreg.toJSON() }),
deserialize: (serializedState) => ({ logreg: LogisticRegression.load(JSON.parse(serializedState)) }),
persistStateKey: 'rendering-type-predictor-state',
persistenceEnabled: true,
...persistenceOptions,
});
}

/**
* Initialize the predictor by restoring persisted state.
*
* Delegates to the underlying `RecoverableState`. `predict` reads `this.state.currentValue`,
* so this should be awaited before the first prediction.
*/
async initialize(): Promise<void> {
await this.state.initialize();
}

/**
Expand All @@ -57,18 +73,16 @@ export class RenderingTypePredictor {
renderingType: RenderingType;
detectionProbabilityRecommendation: number;
} {
if (this.logreg.classifiers.length === 0) {
const { logreg } = this.state.currentValue;
if (logreg.classifiers.length === 0) {
return { renderingType: 'clientOnly', detectionProbabilityRecommendation: 1 };
}

const predictionUrl = new URL(loadedUrl ?? url);

const urlFeature = new Matrix([this.calculateFeatureVector(urlComponents(predictionUrl), label)]);
const [prediction] = this.logreg.predict(urlFeature);
const scores = [
this.logreg.classifiers[0].testScores(urlFeature),
this.logreg.classifiers[1].testScores(urlFeature),
];
const [prediction] = logreg.predict(urlFeature);
const scores = [logreg.classifiers[0].testScores(urlFeature), logreg.classifiers[1].testScores(urlFeature)];

return {
renderingType: prediction === 1 ? 'static' : 'clientOnly',
Expand Down Expand Up @@ -134,6 +148,6 @@ export class RenderingTypePredictor {
}
}

this.logreg.train(new Matrix(X), Matrix.columnVector(Y));
this.state.currentValue.logreg.train(new Matrix(X), Matrix.columnVector(Y));
}
}
4 changes: 4 additions & 0 deletions packages/playwright-crawler/src/logistic-regression.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@ declare module 'ml-logistic-regression' {
train(X: Matrix, Y: Matrix): void;

predict(Xtest: Matrix): number[];

static load(model: Record<string, unknown>): LogisticRegression;

toJSON(): Record<string, unknown>;
}
}
6 changes: 3 additions & 3 deletions packages/utils/tsconfig.build.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"extends": "../../tsconfig.build.json",
"compilerOptions": {
"outDir": "./dist"
"outDir": "./dist",
"rootDir": "./src"
},
"include": ["src/**/*"],
"rootDir": "./src"
"include": ["src/**/*"]
}
1 change: 1 addition & 0 deletions test/core/crawlers/adaptive_playwright_crawler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ describe('AdaptivePlaywrightCrawler', () => {
detectionProbabilityRecommendation: number;
renderingType: 'clientOnly' | 'static';
}) => ({
initialize: async () => {},
predict: vi.fn((_request: Request) => prediction),
storeResult: vi.fn((_request: Request, _renderingType: string) => {}),
});
Expand Down
Loading