@@ -10,24 +10,14 @@ library(dplyr)
1010# 1. Define the DecisionTree and MarkovModel classes
1111# ##############################
1212
13- # DecisionTree constructors and methods
1413DecisionTree <- function (data , ... ) {
15- UseMethod(" DecisionTree" )
16- }
17-
18- DecisionTree.default <- function (data , ... ) {
1914 call_obj <- match.call()
2015 structure(list (data = data ),
2116 call = call_obj ,
2217 class = c(" DecisionTree" , " Model" ))
2318}
2419
25- # MarkovModel constructors and methods
26- MarkovModel <- function (model , ... ) {
27- UseMethod(" MarkovModel" )
28- }
29-
30- MarkovModel.default <- function (model = NA ,
20+ MarkovModel <- function (model = NA ,
3121 init_probs = NA ,
3222 trans_matrix = NA ,
3323 cost_matrix = NA ,
@@ -58,21 +48,6 @@ CombinedModel <- function(...) {
5848 structure(models , class = " CombinedModel" )
5949}
6050
61- # Infix operator for chaining (optional)
62- `%->%` <- function (mod1 , mod2 ) {
63- UseMethod(" %->%" )
64- }
65- `%->%.default` <- function (mod1 , mod2 ) {
66- if (! inherits(mod2 , " Model" )) {
67- stop(" All arguments must be of class 'Model'" )
68- } else {
69- stop(" No method for this model" )
70- }
71- }
72- `%->%.Model` <- function (mod1 , mod2 ) {
73- CombinedModel(mod1 , mod2 )
74- }
75-
7651# ##############################
7752# 3. Define the runner functions for each model
7853# ##############################
@@ -112,8 +87,8 @@ run_model.MarkovModel <- function(model) {
11287 class = c(" MarkovModelOutput" , " output" , class(model )))
11388}
11489
115- # Run a CombinedModel: update each model in the chain (using update_model())
116- # and store each model's output in a list.
90+ # update each model in the chain (using update_model())
91+ # and store each model's output in a list
11792run_model.CombinedModel <- function (model_chain ) {
11893 results <- list ()
11994
@@ -122,14 +97,14 @@ run_model.CombinedModel <- function(model_chain) {
12297
12398 if (i > 1 ) {
12499 # Update current_model based on the output of the previous model
125- # The update_model() S3 function inspects the previous result’s class
100+ # S3 function inspects the previous result’s class
126101 current_model <- update_model(current_model , results [[i - 1 ]])
127102 }
128103
129104 results [[i ]] <- run_model(current_model )
130105 }
131106
132- structure(results , class = c(" CombinedModelOutput" , " output " ))
107+ structure(results , class = c(" CombinedModelOutput" , class( results ) ))
133108}
134109
135110# ##############################
0 commit comments