11import type { Request } from '@crawlee/core' ;
2+ import { type PersistenceOptions , RecoverableState } from '@crawlee/utils' ;
23import LogisticRegression from 'ml-logistic-regression' ;
34import { Matrix } from 'ml-matrix' ;
45import stringComparison from 'string-comparison' ;
@@ -33,6 +34,7 @@ type FeatureVector = [staticResultsSimilarity: number, clientOnlyResultsSimilari
3334export interface RenderingTypePredictorOptions {
3435 /** A number between 0 and 1 that determines the desired ratio of rendering type detections */
3536 detectionRatio : number ;
 /**
  * Optional overrides for how the predictor's state is persisted. Any subset of `PersistenceOptions`
  * may be supplied; the `RenderingTypePredictor` constructor spreads these over its own defaults
  * (`persistStateKey: 'rendering-type-predictor-state'`, `persistenceEnabled: true`).
  */
37+ persistenceOptions ?: Partial < PersistenceOptions > ;
3638}
3739
3840/**
@@ -43,11 +45,22 @@ export interface RenderingTypePredictorOptions {
4345export class RenderingTypePredictor {
4446 private renderingTypeDetectionResults = new Map < RenderingType , Map < string | undefined , URLComponents [ ] > > ( ) ;
4547 private detectionRatio : number ;
46- private logreg : LogisticRegression ;
48+ private state : RecoverableState < { logreg : LogisticRegression } > ;
4749
/**
 * Creates a predictor whose logistic-regression model is held in a `RecoverableState`.
 *
 * @param options.detectionRatio Stored on the instance; the desired ratio of rendering type detections.
 * @param options.persistenceOptions Optional overrides spread over the default persistence settings
 *   (`persistStateKey: 'rendering-type-predictor-state'`, `persistenceEnabled: true`).
 */
constructor({ detectionRatio, persistenceOptions }: RenderingTypePredictorOptions) {
    this.detectionRatio = detectionRatio;
    this.state = new RecoverableState({
        defaultState: { logreg: new LogisticRegression({ numSteps: 1000, learningRate: 0.05 }) },
        serialize: (state) => ({ logreg: state.logreg.toJSON() }),
        // `serialize` nests the model JSON under the `logreg` key, so unwrap that key here;
        // passing the whole wrapper to `load` would not restore the saved model.
        deserialize: (serializedState) => ({ logreg: LogisticRegression.load(serializedState.logreg) }),
        // Defaults come first so that caller-supplied options take precedence.
        persistStateKey: 'rendering-type-predictor-state',
        persistenceEnabled: true,
        ...persistenceOptions,
    });
}
5265
5366 /**
@@ -57,18 +70,16 @@ export class RenderingTypePredictor {
5770 renderingType : RenderingType ;
5871 detectionProbabilityRecommendation : number ;
5972 } {
60- if ( this . logreg . classifiers . length === 0 ) {
73+ const { logreg } = this . state . currentValue ;
74+ if ( logreg . classifiers . length === 0 ) {
6175 return { renderingType : 'clientOnly' , detectionProbabilityRecommendation : 1 } ;
6276 }
6377
6478 const predictionUrl = new URL ( loadedUrl ?? url ) ;
6579
6680 const urlFeature = new Matrix ( [ this . calculateFeatureVector ( urlComponents ( predictionUrl ) , label ) ] ) ;
67- const [ prediction ] = this . logreg . predict ( urlFeature ) ;
68- const scores = [
69- this . logreg . classifiers [ 0 ] . testScores ( urlFeature ) ,
70- this . logreg . classifiers [ 1 ] . testScores ( urlFeature ) ,
71- ] ;
81+ const [ prediction ] = logreg . predict ( urlFeature ) ;
82+ const scores = [ logreg . classifiers [ 0 ] . testScores ( urlFeature ) , logreg . classifiers [ 1 ] . testScores ( urlFeature ) ] ;
7283
7384 return {
7485 renderingType : prediction === 1 ? 'static' : 'clientOnly' ,
@@ -134,6 +145,6 @@ export class RenderingTypePredictor {
134145 }
135146 }
136147
137- this . logreg . train ( new Matrix ( X ) , Matrix . columnVector ( Y ) ) ;
148+ this . state . currentValue . logreg . train ( new Matrix ( X ) , Matrix . columnVector ( Y ) ) ;
138149 }
139150}
0 commit comments