Skip to content

Commit 6c1cb35

Browse files
committed
Use RecoverableState in RenderingTypePredictor
1 parent 76427ff commit 6c1cb35

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

packages/playwright-crawler/src/internals/utils/rendering-type-prediction.ts

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import type { Request } from '@crawlee/core';
2+
import { type PersistenceOptions, RecoverableState } from '@crawlee/utils';
23
import LogisticRegression from 'ml-logistic-regression';
34
import { Matrix } from 'ml-matrix';
45
import stringComparison from 'string-comparison';
@@ -33,6 +34,7 @@ type FeatureVector = [staticResultsSimilarity: number, clientOnlyResultsSimilari
3334
export interface RenderingTypePredictorOptions {
3435
/** A number between 0 and 1 that determines the desired ratio of rendering type detections */
3536
detectionRatio: number;
37+
/** Optional overrides for state persistence; the constructor spreads them over its defaults (`persistStateKey: 'rendering-type-predictor-state'`, `persistenceEnabled: true`). */
persistenceOptions?: Partial<PersistenceOptions>;
3638
}
3739

3840
/**
@@ -43,11 +45,22 @@ export interface RenderingTypePredictorOptions {
4345
export class RenderingTypePredictor {
4446
private renderingTypeDetectionResults = new Map<RenderingType, Map<string | undefined, URLComponents[]>>();
4547
private detectionRatio: number;
46-
private logreg: LogisticRegression;
48+
private state: RecoverableState<{ logreg: LogisticRegression }>;
4749

48-
constructor({ detectionRatio }: RenderingTypePredictorOptions) {
50+
constructor({ detectionRatio, persistenceOptions }: RenderingTypePredictorOptions) {
4951
this.detectionRatio = detectionRatio;
50-
this.logreg = new LogisticRegression({ numSteps: 1000, learningRate: 0.05 });
52+
this.state = new RecoverableState({
53+
defaultState: { logreg: new LogisticRegression({ numSteps: 1000, learningRate: 0.05 }) },
54+
serialize: (state) => ({ logreg: state.logreg.toJSON() }),
55+
// `serialize` stores the model JSON under the `logreg` key, so unwrap that key before handing it to `LogisticRegression.load`; passing the wrapper object itself would not be a valid model.
deserialize: (serializedState) => ({ logreg: LogisticRegression.load(serializedState.logreg) }),
56+
...{
57+
...{
58+
persistStateKey: 'rendering-type-predictor-state',
59+
persistenceEnabled: true,
60+
},
61+
...persistenceOptions,
62+
},
63+
});
5164
}
5265

5366
/**
@@ -57,18 +70,16 @@ export class RenderingTypePredictor {
5770
renderingType: RenderingType;
5871
detectionProbabilityRecommendation: number;
5972
} {
60-
if (this.logreg.classifiers.length === 0) {
73+
const { logreg } = this.state.currentValue;
74+
if (logreg.classifiers.length === 0) {
6175
return { renderingType: 'clientOnly', detectionProbabilityRecommendation: 1 };
6276
}
6377

6478
const predictionUrl = new URL(loadedUrl ?? url);
6579

6680
const urlFeature = new Matrix([this.calculateFeatureVector(urlComponents(predictionUrl), label)]);
67-
const [prediction] = this.logreg.predict(urlFeature);
68-
const scores = [
69-
this.logreg.classifiers[0].testScores(urlFeature),
70-
this.logreg.classifiers[1].testScores(urlFeature),
71-
];
81+
const [prediction] = logreg.predict(urlFeature);
82+
const scores = [logreg.classifiers[0].testScores(urlFeature), logreg.classifiers[1].testScores(urlFeature)];
7283

7384
return {
7485
renderingType: prediction === 1 ? 'static' : 'clientOnly',
@@ -134,6 +145,6 @@ export class RenderingTypePredictor {
134145
}
135146
}
136147

137-
this.logreg.train(new Matrix(X), Matrix.columnVector(Y));
148+
this.state.currentValue.logreg.train(new Matrix(X), Matrix.columnVector(Y));
138149
}
139150
}

0 commit comments

Comments
 (0)