Skip to content

Commit ad505e7

Browse files
committed
added --prune option to convert IB1 trees into IG trees
issue #13
1 parent 0596b57 commit ad505e7

File tree

13 files changed

+232
-54
lines changed

13 files changed

+232
-54
lines changed

include/timbl/GetOptClass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ namespace Timbl {
9292
bool do_sloppy_loo;
9393
bool do_silly;
9494
bool do_diversify;
95+
bool do_prune;
9596
std::vector<MetricType>metricsArray;
9697
std::ostream *parent_socket_os;
9798
std::string inPath;

include/timbl/IBtree.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ namespace Timbl {
7878
IBtree( const IBtree& ) = delete; // forbid copies
7979
IBtree& operator=( const IBtree& ) = delete; // forbid copies
8080
~IBtree();
81-
IBtree *Reduce( const TargetValue *, unsigned long&, long );
81+
IBtree *Reduce( const TargetValue *,
82+
unsigned long&,
83+
long,
84+
bool,
85+
ClassDistribution*& );
8286
#ifdef IBSTATS
8387
static inline IBtree *add_feat_val( FeatureValue *,
8488
unsigned int&,
@@ -119,7 +123,6 @@ namespace Timbl {
119123
InstanceBase_base( size_t, unsigned long&, bool, bool );
120124
virtual ~InstanceBase_base( void ) override;
121125
void AssignDefaults( void );
122-
void RedoDistributions();
123126
bool AddInstance( const Instance& );
124127
void RemoveInstance( const Instance& );
125128
void summarizeNodes( std::vector<unsigned int>&,
@@ -164,8 +167,8 @@ namespace Timbl {
164167
Feature_List&,
165168
Targets&,
166169
int );
167-
virtual void Prune( const TargetValue *, long = 0 );
168-
virtual bool IsPruned() const { return false; };
170+
virtual void Prune( const TargetValue *, bool=false, long = 0 );
171+
bool IsPruned() const { return Pruned; };
169172
void CleanPartition( bool );
170173
unsigned long int GetSizeInfo( unsigned long int&, double & ) const;
171174
const ClassDistribution *TopDist() const { return TopDistribution; };
@@ -184,6 +187,7 @@ namespace Timbl {
184187
bool DefaultsValid;
185188
bool Random;
186189
bool PersistentDistributions;
190+
bool Pruned;
187191
int Version;
188192
ClassDistribution *TopDistribution;
189193
WClassDistribution *WTop;
@@ -231,14 +235,17 @@ namespace Timbl {
231235

232236
class IB_InstanceBase: public InstanceBase_base {
233237
public:
234-
IB_InstanceBase( size_t size, unsigned long& cnt, bool rand ):
235-
InstanceBase_base( size, cnt, rand , false ),
238+
IB_InstanceBase( size_t size,
239+
unsigned long& cnt,
240+
bool rand ):
241+
InstanceBase_base( size, cnt, rand , false ),
236242
offSet(0),
237243
effFeat(0),
238244
testInst(0)
239245
{};
240246
IB_InstanceBase *Copy() const override;
241247
IB_InstanceBase *clone() const override;
248+
void Prune( const TargetValue *, bool=false, long = 0 ) override;
242249
const ClassDistribution *InitGraphTest( std::vector<FeatureValue *>&,
243250
const std::vector<FeatureValue *> *,
244251
const size_t,
@@ -255,12 +262,11 @@ namespace Timbl {
255262
public:
256263
IG_InstanceBase( size_t size, unsigned long& cnt,
257264
bool rand, bool pruned, bool keep_dists ):
258-
InstanceBase_base( size, cnt, rand, keep_dists ), Pruned( pruned ) {};
265+
InstanceBase_base( size, cnt, rand, keep_dists ) { Pruned = pruned; };
259266
IG_InstanceBase *clone() const override;
260267
IG_InstanceBase *Copy() const override;
261-
void Prune( const TargetValue *, long = 0 ) override;
268+
void Prune( const TargetValue *, bool = false, long = 0 ) override;
262269
void specialPrune( const TargetValue * );
263-
bool IsPruned() const override { return Pruned; };
264270
const ClassDistribution *IG_test( const Instance& ,
265271
size_t&,
266272
bool&,
@@ -274,8 +280,6 @@ namespace Timbl {
274280
Targets&,
275281
int ) override;
276282
bool MergeSub( InstanceBase_base * ) override;
277-
protected:
278-
bool Pruned;
279283
};
280284

281285
class TRIBL_InstanceBase: public InstanceBase_base {

include/timbl/MBLClass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ namespace Timbl {
8282
MBLClass& operator=( const MBLClass& );
8383
enum PhaseValue { TrainWords, LearnWords, TestWords, TrainLearnWords };
8484
friend std::ostream& operator<< ( std::ostream&, const PhaseValue& );
85-
enum IB_Stat { Invalid, Normal, Pruned };
85+
enum class IB_Stat { Invalid, Normal, Pruned };
8686

8787
bool writeArrays( std::ostream& );
8888
bool readArrays( std::istream& );
@@ -194,6 +194,7 @@ namespace Timbl {
194194
bool tableFilled;
195195
MetricType globalMetricOption;
196196
bool do_diversify;
197+
bool do_prune;
197198
bool initProbabilityArrays( bool );
198199
void calculatePrestored();
199200
void initDecay();

include/timbl/TimblAPI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ namespace Timbl{
7777
bool Decrement( const std::string& );
7878
bool Expand( const std::string& );
7979
bool Remove( const std::string& );
80+
bool Prune( bool = false );
8081
bool Test( const std::string& = "",
8182
const std::string& = "",
8283
const std::string& = "" );
@@ -142,6 +143,7 @@ namespace Timbl{
142143
bool ShowOptions( std::ostream& ) const;
143144
bool ShowSettings( std::ostream& ) const;
144145
bool ShowIBInfo( std::ostream& ) const;
146+
bool LearningInfo( std::ostream& ) const;
145147
bool ShowStatistics( std::ostream& ) const;
146148
bool SetOptions( const std::string& );
147149
bool SetIndirectOptions( const TiCC::CL_Options& );

src/GetOptClass.cxx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ namespace Timbl {
7676
do_sloppy_loo = false;
7777
do_silly = false;
7878
do_diversify = false;
79+
do_prune = false;
7980
if ( MaxFeats == -1 ){
8081
MaxFeats = Max;
8182
LocalInputFormat = UnknownInputFormat; // InputFormat and verbosity
@@ -155,6 +156,7 @@ namespace Timbl {
155156
do_sloppy_loo( false ),
156157
do_silly( in.do_silly ),
157158
do_diversify( in.do_diversify ),
159+
do_prune( in.do_prune ),
158160
metricsArray( in.metricsArray ),
159161
parent_socket_os( in.parent_socket_os ),
160162
outPath( in.outPath ),
@@ -240,6 +242,12 @@ namespace Timbl {
240242
return false;
241243
}
242244
}
245+
if ( do_prune ){
246+
optline = "DO_PRUNE: true";
247+
if ( !Exp->SetOption( optline ) ){
248+
return false;
249+
}
250+
}
243251
if ( f_length > 0 ){
244252
optline = "FLENGTH: " + TiCC::toString<int>(f_length);
245253
if ( !Exp->SetOption( optline ) ){
@@ -943,7 +951,18 @@ namespace Timbl {
943951
break;
944952

945953
case 'p':
946-
local_progress = TiCC::stringTo<int>( value );
954+
if ( longOpt ){
955+
if ( option == "prune" ){
956+
do_prune = true;
957+
}
958+
else {
959+
Error( "invalid option: Did you mean '--prune' ?" );
960+
return false;
961+
}
962+
}
963+
else {
964+
local_progress = TiCC::stringTo<int>( value );
965+
}
947966
break;
948967

949968
case 'q':

src/IBprocs.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ namespace Timbl {
8080

8181
// Report the state of the current instance base:
//   IB_Stat::Invalid  when there is no instance base at all,
//   IB_Stat::Pruned   when the instance base reports IsPruned(),
//   IB_Stat::Normal   in every other case.
MBLClass::IB_Stat MBLClass::IBStatus() const {
  if ( !InstanceBase ){
    return IB_Stat::Invalid;
  }
  return InstanceBase->IsPruned() ? IB_Stat::Pruned : IB_Stat::Normal;
}
9292

src/IBtree.cxx

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ namespace Timbl {
211211
os << ")\n";
212212
}
213213

214-
void InstanceBase_base::write_tree_hashed( ostream &os, const IBtree *pnt ) const {
214+
void InstanceBase_base::write_tree_hashed( ostream &os,
215+
const IBtree *pnt ) const {
215216
// part of saving a tree in a recoverable manner
216217
os << "(" << pnt->TValue->Index();
217218
if ( pnt->link ){
@@ -677,6 +678,7 @@ namespace Timbl {
677678
NumOfTails = 0;
678679
DefAss = true; // always for a restored tree
679680
DefaultsValid = true; // always for a restored tree
681+
Pruned = false;
680682
Version = expected_version;
681683
char delim;
682684
is >> delim;
@@ -980,20 +982,70 @@ namespace Timbl {
980982

981983
// Recursively cut default nodes (through make_unique()), working from the
// leaves of the tree back to the top.
//
//   Top        : target value handed to make_unique() for this level
//   cnt        : counter forwarded to make_unique() and to the recursion
//   depth      : remaining depth; at depth <= 0 this level is passed to
//                make_unique(), otherwise the level itself is returned
//   keep_dists : when true, the class distributions found below are
//                gathered upward while reducing
//   dist       : in/out accumulator for the gathered distributions; it is
//                only touched when keep_dists is true and it may still hold
//                a heap object on return (IB_InstanceBase::Prune deletes
//                what is left)
//
// Returns the (possibly replaced) first node of this level.
inline IBtree *IBtree::Reduce( const TargetValue *Top,
                               unsigned long& cnt,
                               long depth,
                               bool keep_dists,
                               ClassDistribution*& dist ){
  for ( IBtree *node = this; node; node = node->next ){
    if ( node->link ){
      // reduce the subtree first; with keep_dists == false the recursion
      // leaves 'below' untouched, so it stays a null pointer
      ClassDistribution *below = nullptr;
      node->link = node->link->Reduce( node->TValue,
                                       cnt,
                                       depth-1,
                                       keep_dists,
                                       below );
      if ( keep_dists && below ){
        // fold the result of the subtree into this node ...
        if ( node->TDistribution ){
          node->TDistribution->Merge( *below );
        }
        else {
          node->TDistribution = below->to_VD_Copy();
        }
        // ... and into the accumulator of our caller
        if ( dist ){
          dist->Merge( *below );
          delete below;
        }
        else {
          // first contribution: pass the object itself on, nothing to delete
          dist = below;
        }
      }
    }
    else if ( keep_dists && node->TDistribution ){
      // a node without a subtree contributes its own distribution
      if ( dist ){
        dist->Merge( *node->TDistribution );
      }
      else {
        dist = node->TDistribution->to_VD_Copy();
      }
    }
  }
  if ( depth <= 0 ){
    return make_unique( Top, cnt );
  }
  if ( keep_dists && !TDistribution ){
    // this node has no distribution yet: it takes over the gathered one
    TDistribution = dist;
    dist = nullptr;
  }
  return this;
}
@@ -1039,6 +1091,7 @@ namespace Timbl {
10391091
DefaultsValid( false ),
10401092
Random( Rand ),
10411093
PersistentDistributions( persist ),
1094+
Pruned( false ),
10421095
Version( 4 ),
10431096
TopDistribution( new ClassDistribution ),
10441097
WTop( 0 ),
@@ -1261,14 +1314,37 @@ namespace Timbl {
12611314
DefaultsValid = true;
12621315
}
12631316

1264-
void InstanceBase_base::Prune( const TargetValue *, long ){
1265-
FatalError( "You Cannot Prune this kind of tree! " );
1317+
// Base-class version of Prune(): it ignores all three arguments and only
// reports a fatal error through FatalError().
// IB_InstanceBase and IG_InstanceBase supply their own Prune().
void InstanceBase_base::Prune( const TargetValue *, bool, long ){
1318+
FatalError( "You can only Prune when using IB1 or IG !" );
12661319
}
12671320

1268-
void IG_InstanceBase::Prune( const TargetValue *top, long depth ){
1269-
AssignDefaults( );
1270-
if ( !Pruned ) {
1271-
InstBase = InstBase->Reduce( top, ibCount, depth );
1321+
void IB_InstanceBase::Prune( const TargetValue *top,
1322+
bool keep_dists,
1323+
long depth ){
1324+
if ( Pruned ) {
1325+
throw runtime_error( "cannot prune a pruned instancebase" );
1326+
}
1327+
else {
1328+
AssignDefaults( );
1329+
ClassDistribution *cd = NULL;
1330+
InstBase = InstBase->Reduce( top, ibCount, depth, keep_dists, cd );
1331+
if ( cd ){
1332+
delete cd;
1333+
}
1334+
Pruned = true;
1335+
}
1336+
}
1337+
1338+
void IG_InstanceBase::Prune( const TargetValue *top,
1339+
bool keep_dists,
1340+
long depth ){
1341+
if ( Pruned ) {
1342+
throw runtime_error( "cannot prune a pruned instancebase" );
1343+
}
1344+
else {
1345+
AssignDefaults( );
1346+
ClassDistribution *cd = NULL;
1347+
InstBase = InstBase->Reduce( top, ibCount, depth, keep_dists, cd );
12721348
Pruned = true;
12731349
}
12741350
}
@@ -1286,7 +1362,8 @@ namespace Timbl {
12861362
}
12871363
bool dummy;
12881364
InstBase->TValue = dist.BestTarget( dummy, Random );
1289-
InstBase = InstBase->Reduce( top, ibCount, 0 );
1365+
ClassDistribution *cd = NULL;
1366+
InstBase = InstBase->Reduce( top, ibCount, 0, false, cd );
12901367
Pruned = true;
12911368
}
12921369

src/IGExperiment.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ namespace Timbl {
240240
if ( PartInstanceBase ){
241241
// time_stamp( "Start Pruning: " );
242242
// cerr << PartInstanceBase << endl;
243-
PartInstanceBase->Prune( TopTarget, 2 );
243+
PartInstanceBase->Prune( TopTarget, false, 2 );
244244
// time_stamp( "Finished Pruning: " );
245245
// cerr << PartInstanceBase << endl;
246246
if ( !TmpInstanceBase->MergeSub( PartInstanceBase ) ){
@@ -341,7 +341,7 @@ namespace Timbl {
341341

342342
bool IG_Experiment::sanityCheck() const {
343343
bool status = true;
344-
if ( IBStatus() != Pruned ){
344+
if ( IBStatus() != IB_Stat::Pruned ){
345345
Warning( "you tried to apply the IGTree algorithm on a complete,"
346346
"(non-pruned) Instance Base" );
347347
status = false;

src/MBLClass.cxx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ namespace Timbl {
106106
&do_silly_testing, false ) );
107107
Options.Add( new BoolOption( "DO_DIVERSIFY",
108108
&do_diversify, false ) );
109+
Options.Add( new BoolOption( "DO_PRUNE",
110+
&do_prune, false ) );
109111
Options.Add( new DecayOption( "DECAY",
110112
&decay_flag, Zero ) );
111113
Options.Add( new IntegerOption( "SEED",
@@ -217,6 +219,7 @@ namespace Timbl {
217219
tableFilled(false),
218220
globalMetricOption(Overlap),
219221
do_diversify(false),
222+
do_prune(false),
220223
ChopInput(0),
221224
F_length(0),
222225
MaxFeatures(0),
@@ -287,6 +290,7 @@ namespace Timbl {
287290
do_sloppy_loo = m.do_sloppy_loo;
288291
do_silly_testing = m.do_silly_testing;
289292
do_diversify = m.do_diversify;
293+
do_prune = m.do_prune;
290294
tester = 0;
291295
decay = 0;
292296
targets = m.targets;

0 commit comments

Comments
 (0)