diff --git a/blocking/inc/gmds/blocking/CurvedBlocking.h b/blocking/inc/gmds/blocking/CurvedBlocking.h index c4d756a51..47149d57d 100644 --- a/blocking/inc/gmds/blocking/CurvedBlocking.h +++ b/blocking/inc/gmds/blocking/CurvedBlocking.h @@ -22,6 +22,15 @@ namespace gmds { /*----------------------------------------------------------------------------*/ namespace blocking { /*----------------------------------------------------------------------------*/ +class Counter{ + public: + Counter(int c) + : m_counter_global_id(c){} + int get_and_increment_id(){return m_counter_global_id++;} + int value(){return m_counter_global_id;} + private: + int m_counter_global_id; +}; /**@struct CellInfo * @brief This structure gather the pieces of data that are shared by any * blocking cell. Each cell is defined by: @@ -41,22 +50,29 @@ struct CellInfo int topo_id; /*** link to the cad manager to have access to geometric cells */ cad::GeomManager* geom_manager; + /*** link to the counter used to assign a unique id to each entity */ + Counter* counter; /*** dimension of the geometrical cell we are classifid on */ int geom_dim; /*** unique id of the geomtrical cell */ int geom_id; - /*** global counter used to assign an unique id to each block */ - static int m_counter_global_id; /** @brief Constructor + * @param Ac the id counter; the CGAL gmap copy constructor requires a CellInfo() + * call with no params * @param AManager the geometric manager to access cells * @param ATopoDim Cell dimension * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id */ - CellInfo(cad::GeomManager* AManager, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : - topo_dim(ATopoDim), topo_id(m_counter_global_id++), geom_manager(AManager),geom_dim(AGeomDim), geom_id(AGeomId) + CellInfo(Counter* Ac=nullptr, cad::GeomManager* AManager=nullptr, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : + topo_dim(ATopoDim), geom_manager(AManager), counter(Ac), geom_dim(AGeomDim), geom_id(AGeomId) { + if(Ac != nullptr) { + topo_id = Ac->get_and_increment_id(); + } else { + topo_id = -1; + } } }; /*----------------------------------------------------------------------------*/ @@ -69,13 +85,15 @@ struct NodeInfo : CellInfo /*** node location in space, i.e. a single point */ math::Point point; /** @brief Constructor + * @param Ac the id counter; the CGAL gmap copy constructor requires a CellInfo() + * call with no params * @param AManager the geometric manager to access cells * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id * @param APoint geometric location */ - NodeInfo(cad::GeomManager* AManager, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : - CellInfo(AManager, 0, AGeomDim, AGeomId), point(APoint) + NodeInfo(Counter* Ac=nullptr, cad::GeomManager* AManager=nullptr, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : + CellInfo(Ac, AManager, 0, AGeomDim, AGeomId), point(APoint) { } }; @@ -208,8 +226,7 @@ struct SplitFunctor ca2.info().geom_dim = ca1.info().geom_dim; ca2.info().geom_id = ca1.info().geom_id; ca2.info().topo_dim = ca1.info().topo_dim; - ca2.info().topo_id = CellInfo::m_counter_global_id++; - + ca2.info().topo_id = ca1.info().counter->get_and_increment_id(); } }; @@ -227,7 +244,7 @@ struct SplitFunctorNode ca2.info().geom_id = ca1.info().geom_id; ca2.info().point = ca1.info().point; ca2.info().topo_dim = ca1.info().topo_dim; - ca2.info().topo_id = CellInfo::m_counter_global_id++; + ca2.info().topo_id = ca1.info().counter->get_and_increment_id(); } }; /*----------------------------------------------------------------------------*/ @@ -278,6 +295,8 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking */ CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox = false); + CurvedBlocking(const CurvedBlocking &ABl); + /** @brief Destructor */ virtual ~CurvedBlocking(); @@ -685,6 +704,11 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking */ std::vector> get_projection_info(math::Point &AP, std::vector &AEdges); + Counter* getCounter() + { + return &m_counter; + } + private: /**@brief Mark with @p AMark all the darts of orbit <0,1>(@p ADart) @@ -736,6 +760,9 @@ class LIB_GMDS_BLOCKING_API CurvedBlocking cad::GeomManager *m_geom_model; /*** the underlying n-g-map model*/ GMap3 m_gmap; + + /*** id counter*/ + Counter m_counter; }; /*----------------------------------------------------------------------------*/ } // namespace blocking diff --git a/blocking/src/CurvedBlocking.cpp b/blocking/src/CurvedBlocking.cpp index b6ff874cc..de461f295 100644 --- a/blocking/src/CurvedBlocking.cpp +++ b/blocking/src/CurvedBlocking.cpp @@ -4,10 +4,12 @@ using namespace gmds; using namespace gmds::blocking; /*----------------------------------------------------------------------------*/ -int CellInfo::m_counter_global_id = 0; +//int CellInfo::m_counter_global_id = 0; /*----------------------------------------------------------------------------*/ -CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox) : m_geom_model(AGeomModel) { +CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundingBox) + : m_geom_model(AGeomModel), m_counter(0) +{ if (AInitAsBoundingBox) { TCoord min[3] = {MAXFLOAT, MAXFLOAT, MAXFLOAT}; TCoord max[3] = {-MAXFLOAT, -MAXFLOAT, -MAXFLOAT}; @@ -32,7 +34,27 @@ CurvedBlocking::CurvedBlocking(cad::GeomManager *AGeomModel, bool AInitAsBoundin create_block(p1, p2, p3, p4, p5, p6, p7, p8); } } - +/*----------------------------------------------------------------------------*/ +CurvedBlocking::CurvedBlocking(const CurvedBlocking &ABl) +: m_geom_model(ABl.m_geom_model), m_gmap(ABl.m_gmap), m_counter(ABl.m_counter) +{ + auto listBlocks = get_all_blocks(); + for(auto b : listBlocks){ + b->info().counter = &m_counter; + } + auto listFaces = get_all_faces(); + for(auto b : listFaces){ + b->info().counter = &m_counter; + } + auto listEdges = get_all_edges(); + for(auto b : listEdges){ + b->info().counter = &m_counter; + } + auto listNodes = get_all_nodes(); + for(auto b : listNodes){ + b->info().counter = &m_counter; + } +} /*----------------------------------------------------------------------------*/ CurvedBlocking::~CurvedBlocking() {} @@ -51,25 +73,25 @@ CurvedBlocking::geom_model() { /*----------------------------------------------------------------------------*/ CurvedBlocking::Node CurvedBlocking::create_node(const int AGeomDim, const int AGeomId, const math::Point &APoint) { - return m_gmap.create_attribute<0>(NodeInfo(m_geom_model,AGeomDim, AGeomId, APoint)); + return m_gmap.create_attribute<0>(NodeInfo(this->getCounter(),m_geom_model,AGeomDim, AGeomId, APoint)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Edge CurvedBlocking::create_edge(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<1>(CellInfo(m_geom_model,1, AGeomDim, AGeomId)); + return m_gmap.create_attribute<1>(CellInfo(this->getCounter(),m_geom_model,1, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Face CurvedBlocking::create_face(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<2>(CellInfo(m_geom_model,2, AGeomDim, AGeomId)); + return m_gmap.create_attribute<2>(CellInfo(this->getCounter(),m_geom_model,2, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/ CurvedBlocking::Block CurvedBlocking::create_block(const int AGeomDim, const int AGeomId) { - return m_gmap.create_attribute<3>(CellInfo(m_geom_model,3, AGeomDim, AGeomId)); + return m_gmap.create_attribute<3>(CellInfo(this->getCounter(),m_geom_model,3, AGeomDim, AGeomId)); } /*----------------------------------------------------------------------------*/ diff --git a/blocking/src/CurvedBlockingClassifier.cpp b/blocking/src/CurvedBlockingClassifier.cpp index cad4c0fda..83ba7ce79 100644 --- a/blocking/src/CurvedBlockingClassifier.cpp +++ b/blocking/src/CurvedBlockingClassifier.cpp @@ -632,7 +632,7 @@ std::vector> CurvedBlockingClassifier::list_Possible_Cuts() { std::vector> list_actions; - auto no_capt_elements = classify(); + auto no_capt_elements = this->classify(); auto no_points_capt = no_capt_elements.non_captured_points; auto no_curves_capt = no_capt_elements.non_captured_curves; diff --git a/blocking/tst/ExecutionActionsTestSuite.h b/blocking/tst/ExecutionActionsTestSuite.h index bf32499b4..27711d43d 100644 --- a/blocking/tst/ExecutionActionsTestSuite.h +++ b/blocking/tst/ExecutionActionsTestSuite.h @@ -706,3 +706,84 @@ TEST(ExecutionActionsTestSuite,cb5){ vtk_writer_edges.write("debug_blocking_edges.vtk"); } + + +TEST(ExecutionActionsTestSuite,cb2_auto) { + gmds::cad::FACManager geom_model; + set_up_file(&geom_model,"cb2.vtk"); + gmds::blocking::CurvedBlocking bl(&geom_model,true); + gmds::blocking::CurvedBlockingClassifier classifier(&bl); + + + classifier.clear_classification(); + + auto errors = classifier.classify(); + + + //Check nb points of the geometry and nb nodes of the blocking + ASSERT_EQ(16,geom_model.getNbPoints()); + ASSERT_EQ(24,geom_model.getNbCurves()); + ASSERT_EQ(10,geom_model.getNbSurfaces()); + ASSERT_EQ(8,bl.get_all_nodes().size()); + ASSERT_EQ(12,bl.get_all_edges().size()); + ASSERT_EQ(6,bl.get_all_faces().size()); + + + + //Check elements class and captured + //Check nb nodes/edges/faces no classified + ASSERT_EQ(0,errors.non_classified_nodes.size()); + ASSERT_EQ(0,errors.non_classified_edges.size()); + ASSERT_EQ(6,errors.non_classified_faces.size()); + + //Check nb points/curves/surfaces no captured + ASSERT_EQ(8,errors.non_captured_points.size()); + ASSERT_EQ(12,errors.non_captured_curves.size()); + ASSERT_EQ(10,errors.non_captured_surfaces.size()); + + auto listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + gmds::Mesh m(gmds::MeshModel(gmds::DIM3|gmds::N|gmds::E|gmds::F|gmds::R|gmds::E2N|gmds::F2N|gmds::R2N)); + bl.convert_to_mesh(m); + + + gmds::IGMeshIOService ios(&m); + gmds::VTKWriter vtk_writer(&ios); + vtk_writer.setCellOptions(gmds::N|gmds::R); + vtk_writer.setDataOptions(gmds::N|gmds::R); + vtk_writer.write("cb2_debug_blocking.vtk"); + gmds::VTKWriter vtk_writer_edges(&ios); + vtk_writer_edges.setCellOptions(gmds::N|gmds::E); + vtk_writer_edges.setDataOptions(gmds::N|gmds::E); + vtk_writer_edges.write("cb2_debug_blocking_edges.vtk"); + gmds::VTKWriter vtk_writer_faces(&ios); + vtk_writer_faces.setCellOptions(gmds::N|gmds::F); + vtk_writer_faces.setDataOptions(gmds::N|gmds::F); + vtk_writer_faces.write("cb2_debug_blocking_faces.vtk"); + + +} + + diff --git a/rlBlocking/CMakeLists.txt b/rlBlocking/CMakeLists.txt index d987d2e8e..ce7fe6055 100644 --- a/rlBlocking/CMakeLists.txt +++ b/rlBlocking/CMakeLists.txt @@ -9,11 +9,24 @@ set(GMDS_INC inc/gmds/rlBlocking/BlockingQuality.h inc/gmds/rlBlocking/LinkerBlockingGeom.h inc/gmds/rlBlocking/ValidBlocking.h + inc/gmds/rlBlocking/MCTSAlgorithm.h + inc/gmds/rlBlocking/MCTSState.h + inc/gmds/rlBlocking/MCTSTree.h + inc/gmds/rlBlocking/MCTSMove.h + inc/gmds/rlBlocking/MCTSMovePolycube.h + inc/gmds/rlBlocking/MCTSStatePolycube.h + inc/gmds/rlBlocking/MCTSAgent.h ) set(GMDS_SRC src/BlockingQuality.cpp src/LinkerBlockingGeom.cpp - src/ValidBlocking.cpp) + src/ValidBlocking.cpp + src/MCTSTree.cpp + src/MCTSAlgorithm.cpp + src/MCTSMovePolycube.cpp + src/MCTSStatePolycube.cpp + src/MCTSAgent.cpp +) #============================================================================== add_library(${GMDS_LIB} ${GMDS_INC} ${GMDS_SRC}) #============================================================================== diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h new file mode 100644 index 000000000..f85f81346 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h @@ -0,0 +1,26 @@ +#ifndef GMDS_MCTSAGENT_H +#define GMDS_MCTSAGENT_H + +#include +#include +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +class LIB_GMDS_RLBLOCKING_API MCTSAgent +{ + // example of an agent based on the MCTS_tree. One can also use the tree directly. + MCTSTree *tree; + int max_iter, max_seconds, max_same_quality; + + public: + MCTSAgent(MCTSState *starting_state, int max_iter = 100000, int max_seconds = 30, int max_same_quality=3); + ~MCTSAgent(); + const MCTSMove *genmove(); + const MCTSState *get_current_state() const; + void feedback() const {tree->print_stats();} +}; +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSAGENT_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h new file mode 100644 index 000000000..083ff9fd4 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h @@ -0,0 +1,47 @@ +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSALGORITHM_H +#define GMDS_MCTSALGORITHM_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSAlgorithm + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSAlgorithm +{ + MCTSTree *tree; + int max_iter, max_seconds,max_same_quality; + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param + */ + MCTSAlgorithm(gmds::cad::GeomManager *AGeom,gmds::blocking::CurvedBlocking *ABlocking,int max_iter = 100000, int max_seconds = 30,int max_same_quality = 10); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + ~MCTSAlgorithm(); + + /*------------------------------------------------------------------------*/ + /** @brief Performs the MCTS algorithm + */ + void execute(); + + private: + /** a geom */ + gmds::cad::GeomManager *m_geom; + /** a blocking */ + gmds::blocking::CurvedBlocking *m_blocking; + +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSALGORITHM_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h new file mode 100644 index 000000000..8b276a22a --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h @@ -0,0 +1,32 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSMOVE_H +#define GMDS_MCTSMOVE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSMove + * @brief Structure that provides .... + */ +struct LIB_GMDS_RLBLOCKING_API MCTSMove { + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSMove() = default; + /*------------------------------------------------------------------------*/ + /** @brief Overloaded == + */ + virtual bool operator==(const MCTSMove& AOther) const = 0; + virtual std::string sprint() const { return "Not implemented"; } + virtual void print() const =0; // and optionally this +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSMOVE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h new file mode 100644 index 000000000..0af5a6449 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h @@ -0,0 +1,38 @@ +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSMOVE_POLYCUBE_H +#define GMDS_MCTSMOVE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSMove + * @brief Structure that provides .... + */ +struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove { + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + ~MCTSMovePolycube(); + /*------------------------------------------------------------------------*/ + TCellID m_AIdEdge; + TCellID m_AIdBlock; + double m_AParamCut; + /** @brief if typeMove=2: delete block, typeMove=1 cut block + */ + unsigned int m_typeMove; + + /** @brief Overloaded == + */ + MCTSMovePolycube(TCellID AIdEdge = -1,TCellID AIdBlock = -1 , double AParamCut = 0,unsigned int ATypeMove = -1); + bool operator==(const MCTSMove& AOther) const; + void print() const; + +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSMOVE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h new file mode 100644 index 000000000..92301494e --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h @@ -0,0 +1,72 @@ +// +// Created by bourmaudp on 02/12/22. +// +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSSTATE_H +#define GMDS_MCTSSTATE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +/*----------------------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSState + * @brief Class that provides the interface to be implemented for performing the + * MCST algorithm + */ +class LIB_GMDS_RLBLOCKING_API MCTSState { + public: + /*--------------------------------------------------------------------*/ + /** @enum Status code for rollout execution + */ + typedef enum { + WIN, + LOSE, + DRAW + } ROLLOUT_STATUS; + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + virtual ~MCTSState() = default; + /*------------------------------------------------------------------------*/ + /** @brief Gives the set of actions that can be tried from the current state + */ + virtual std::deque *actions_to_try() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Performs the @p AMove to change of states + * @param[in] AMove the movement to apply to get to a new state + */ + virtual MCTSState *next_state(const MCTSMove *AMove) const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Rollout from this state (random simulation) + * @return the rollout status + */ + virtual double state_rollout() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief check the result of a terminal state + * @return the value of the result: Win, Lose, Draw + */ + virtual ROLLOUT_STATUS result_terminal() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + virtual bool is_terminal() const = 0; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + virtual double get_quality() const = 0; + + virtual void print() const { + std::cout << "Printing not implemented" << std::endl; + } +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSSTATE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h new file mode 100644 index 000000000..bd4341015 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h @@ -0,0 +1,87 @@ +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSSTATE_POLYCUBE_H +#define GMDS_MCTSSTATE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +#include +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSState + * @brief Class that provides the interface to be implemented for performing the + * MCST algorithm + */ +class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ + + public: + /*------------------------------------------------------------------------*/ + /** @brief Constructore + */ + MCTSStatePolycube(gmds::cad::GeomManager *Ageom, gmds::blocking::CurvedBlocking *ABlocking, + std::vector AHist); + /*------------------------------------------------------------------------*/ + /** @brief Destructor + */ + ~MCTSStatePolycube(); + /*------------------------------------------------------------------------*/ + /** @brief Gives the set of actions that can be tried from the current state + */ + std::deque *actions_to_try() const ; + /*------------------------------------------------------------------------*/ + /** @brief Performs the @p AMove to change of states + * @param[in] AMove the movement to apply to get to a new state + */ + MCTSState *next_state(const MCTSMove *AMove) const; + /*------------------------------------------------------------------------*/ + /** @brief Rollout from this state (random simulation) + * @return the rollout status + */ + double state_rollout() const; + + /** @brief check the history of qualities + * @return nb of same quality from the history + */ + int check_nb_same_quality() const; + /** @brief check the result of a terminal state + * @return Win = all elements are capt, Lose: parent_quality < enfant_quality, + * Draw : same quality for a long time + */ + ROLLOUT_STATUS result_terminal() const; + /*------------------------------------------------------------------------*/ + /** @brief Indicate if we have a terminal state (win=true, fail=false) + * @return true if we have a leaf (in the sense of a traditional tree) + */ + bool is_terminal() const; + /** @brief return the blocking quality + * */ + double get_quality() const; + /** @brief return the geom */ + gmds::cad::GeomManager *get_geom(); + /** @brief return the current blocking */ + gmds::blocking::CurvedBlocking *get_blocking(); + /** @brief return the current classifier */ + gmds::blocking::CurvedBlockingClassifier *get_class(); + /** @brief return the current classification */ + gmds::blocking::ClassificationErrors get_errors(); + /** @brief return the history of the parents quality */ + std::vector get_history() const; + + /** @brief update the classification of a state */ + void update_class(); + + private : + /** @brief the curved blocking of the current state */ + gmds::blocking::CurvedBlocking* m_blocking; + gmds::cad::GeomManager* m_geom; + gmds::blocking::CurvedBlockingClassifier* m_class_blocking; + gmds::blocking::ClassificationErrors m_class_errors; + std::vector m_history; +}; +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSSTATE_POLYCUBE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h new file mode 100644 index 000000000..e8e2328c0 --- /dev/null +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h @@ -0,0 +1,141 @@ +/*----------------------------------------------------------------------------------------*/ +#ifndef GMDS_MCTSTREE_H +#define GMDS_MCTSTREE_H +/*----------------------------------------------------------------------------------------*/ +#include "LIB_GMDS_RLBLOCKING_export.h" +#include +#include +#include +#include +#include +/*----------------------------------------------------------------------------------------*/ +namespace gmds { +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSNode + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSNode { + /** @brief Yes if the node have no childs */ + bool terminal; + /** @brief Number of nodes in the tree from the node. */ + unsigned int size; + /** @brief number of parent nodes having the same quality*/ + unsigned int nb_same_quality; + /** @brief Number of visits*/ + unsigned int number_of_simulations; + /** @brief e.g. number of wins (could be int but double is more general if we use evaluation functions)*/ + double score; + /** @brief state of the current node */ + MCTSState *state; + /** @brief move to get here from parent node's state*/ + const MCTSMove *move; + /** @brief the chilren for the current node */ + std::vector *children; + /** @brief the parent for the current node*/ + MCTSNode *parent; + /** @brief queue of untried actions*/ + std::deque *untried_actions; + /** @brief update the nb simulations and the score after a rollout*/ + void backpropagate(double w, int n); + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param AParent the parent the node + * @param AMove the action to access at this node + */ + MCTSNode(MCTSNode *AParent, const MCTSMove *AMove, MCTSState *AState); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + virtual ~MCTSNode(); + + /** @brief Check if the node is fully expanded */ + bool is_fully_expanded() const; + /** @brief Check if the node is terminal. + * @param Number max of parents nodes with the same quality + * */ + bool is_terminal() const; + /** @brief Return the different moves/actions possible for a node */ + const MCTSMove *get_move() const; + /** @brief Return the size. */ + unsigned int get_size() const; + /** @brief Expand the node. */ + void expand(); + /** @brief Do a rollout. */ + void rollout(); + /** @brief Select the most promising child of the root node */ + MCTSNode *select_best_child(double c) const; + /** @brief Find child with this m and delete all others. + * @param m the selected move + * @return the next root*/ + MCTSNode *advance_tree(const MCTSMove *m); + /** @brief Return the state of the node. */ + const MCTSState *get_current_state() const; + /** @brief Return the children of the node. */ + std::vector *get_children(); + /** @brief Print the tree and the stats. */ + void print_stats() const; + /** @brief Calculate the q rate of a node. It's: wins-looses */ + double q_rate() const; + /** @brief Calculate UCT. */ + double calculate_UCT() const; + /** @brief Calculate winrate. */ + double calculate_winrate() const; + + + + private: + /** a mesh */ + //Mesh* m_mesh; +}; + +/*----------------------------------------------------------------------------------------*/ +/** @class MCTSTree + * @brief Class that provides .... + */ +class LIB_GMDS_RLBLOCKING_API MCTSTree +{ + MCTSNode *root; + public: + + /*------------------------------------------------------------------------*/ + /** @brief Constructor. + * @param + */ + MCTSTree(MCTSState *starting_state); + + /*------------------------------------------------------------------------*/ + /** @brief Destructor. */ + virtual ~MCTSTree(); + + /** @brief select child node to expand according to tree policy (UCT). + * @param c exploration parameter, theoretically equal to √2 + * @return a node + */ + MCTSNode *select(double c=1.41); + MCTSNode *select_best_child(); + void grow_tree(int max_iter, double max_time_in_seconds); + /** @brief if the move is applicable advance the tree, else start over + * @param move the move to do + * */ + void advance_tree(const MCTSMove *move); + /** @brief get the size of the tree. */ + unsigned int get_size() const; + /** @brief get the size of the tree. */ + const MCTSState *get_current_state() const; + /** @brief Print stats. */ + void print_stats() const; + + + private: + /** a mesh */ + //Mesh* m_mesh; +}; +/*----------------------------------------------------------------------------*/ + +/*----------------------------------------------------------------------------*/ +} +/*----------------------------------------------------------------------------------------*/ +#endif // GMDS_MCTSTREE_H +/*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSAgent.cpp b/rlBlocking/src/MCTSAgent.cpp new file mode 100644 index 000000000..01bfd526e --- /dev/null +++ b/rlBlocking/src/MCTSAgent.cpp @@ -0,0 +1,47 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSAgent::MCTSAgent(gmds::MCTSState *starting_state, int max_iter, int max_seconds, int max_same_quality) + :max_iter(max_iter),max_seconds(max_seconds),max_same_quality(max_same_quality) +{ + tree = new MCTSTree(starting_state); +} + +/*----------------------------------------------------------------------------*/ +MCTSAgent::~MCTSAgent(){ + delete tree; +} +/*----------------------------------------------------------------------------*/ +const MCTSMove *MCTSAgent::genmove() +{ + // If game ended from opponent move, we can't do anything + if (tree->get_current_state()->is_terminal()) { + return NULL; + } +#ifdef DEBUG + std::cout << "___ DEBUG ______________________" << endl + << "Growing tree..." << std::endl; +#endif + tree->grow_tree(max_iter, max_seconds); +#ifdef DEBUG + cout << "Tree size: " << tree->get_size() << endl + << "________________________________" << endl; +#endif + MCTSNode *best_child = tree->select_best_child(); + if (best_child == NULL) { + std::cerr << "Warning: Tree root has no children! Possibly terminal node!" << std::endl; + return NULL; + } + const MCTSMove *best_move = best_child->get_move(); + tree->advance_tree(best_move); + return best_move; +} +/*----------------------------------------------------------------------------*/ +const + MCTSState *MCTSAgent::get_current_state() const +{ + return tree->get_current_state(); +} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSAlgorithm.cpp b/rlBlocking/src/MCTSAlgorithm.cpp new file mode 100644 index 000000000..4338cbb17 --- /dev/null +++ b/rlBlocking/src/MCTSAlgorithm.cpp @@ -0,0 +1,47 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSAlgorithm::MCTSAlgorithm(gmds::cad::GeomManager* AGeom,gmds::blocking::CurvedBlocking* ABlocking,int max_iter, int max_seconds, int max_same_quality) + : m_geom(AGeom),m_blocking(ABlocking),max_iter(max_iter), max_seconds(max_seconds),max_same_quality(max_same_quality) +{ std::vector hist_empty; + MCTSStatePolycube *init_state = new MCTSStatePolycube(m_geom,m_blocking,hist_empty); + tree = new MCTSTree(init_state);} +/*----------------------------------------------------------------------------*/ +MCTSAlgorithm::~MCTSAlgorithm(){;} +/*----------------------------------------------------------------------------*/ +void MCTSAlgorithm::execute() +{ + std::cout<<"==========================================================="<m_geom, this->m_blocking, std::vector ()); + //state->print(); // IMPORTANT: state will be garbage after advance_tree() + MCTSAgent agent(state, 100); + do { + agent.feedback(); + agent.genmove(); + // TODO: This way we don't check if the enemy move ends the game but it's our responsibility to check that, not the tree's... + const MCTSState *new_state = agent.get_current_state(); + new_state->print(); +// if (new_state->is_terminal()) { +// winner = ((const TicTacToe_state *) new_state)->get_winner(); +// break; +// } + done = new_state->is_terminal(); + } while (!done); + + + std::cout<<"==========================================================="< +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +MCTSMovePolycube::~MCTSMovePolycube() +{} +/*----------------------------------------------------------------------------*/ +MCTSMovePolycube::MCTSMovePolycube(TCellID AIdEdge, TCellID AIdBlock, double AParamCut, unsigned int ATypeMove) + :m_AIdEdge(AIdEdge),m_AIdBlock(AIdBlock), m_AParamCut(AParamCut),m_typeMove(ATypeMove) +{} +/*----------------------------------------------------------------------------*/ +bool +MCTSMovePolycube::operator==(const gmds::MCTSMove &AOther) const +{ + const MCTSMovePolycube &o = (const MCTSMovePolycube &) AOther; // Note: Casting necessary + return m_AIdEdge == o.m_AIdEdge && m_AIdBlock == o.m_AIdBlock + && m_AParamCut == o.m_AParamCut && m_typeMove== o.m_typeMove; +} +/*----------------------------------------------------------------------------*/ +void MCTSMovePolycube::print() const +{ + std::cout<<"m_AIdEdge : "<< m_AIdEdge<<" ; m_AIdBlock : "< +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; + + +/*----------------------------------------------------------------------------*/ +MCTSStatePolycube::MCTSStatePolycube(gmds::cad::GeomManager* AGeom, gmds::blocking::CurvedBlocking* ABlocking, + std::vector hist ) + :m_geom(AGeom),m_history(hist) +{ + m_blocking = new blocking::CurvedBlocking(*ABlocking); + gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); + m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); + m_class_errors = m_class_blocking->classify(0.2); + ;} +/*----------------------------------------------------------------------------*/ +MCTSStatePolycube::~MCTSStatePolycube() noexcept +{ + delete m_class_blocking; + delete m_blocking; +} +/*----------------------------------------------------------------------------*/ +std::deque * +MCTSStatePolycube::actions_to_try() const +{ + std::deque *Q = new std::deque(); + if (m_class_errors.non_captured_points.size()== 0){ + std::cout<<"POINTS CAPT :"<get_all_id_blocks(); + for(auto b : blocks){ + Q->push_back(new MCTSMovePolycube(NullID,b,0,2)); + } + } + else{ + std::cout<<"NB CURVES CAPT :"<< m_class_errors.non_captured_curves.size()<list_Possible_Cuts(); + for(auto c : listPossibleCuts){ + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); + } + } + } + else{ + std::cout<<"POINTS NO CAPT :"<list_Possible_Cuts(); + for(auto c : listPossibleCuts){ + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); + } + } + std::cout<<"LIST ACTIONS :"<print(); + } + return Q; +} +/*----------------------------------------------------------------------------*/ +MCTSState + *MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const +{ + std::cout<<"==================== EXECUTE ACTION ! ===================="< hist_update = get_history(); + hist_update.push_back(get_quality()); + gmds::blocking::CurvedBlocking* new_b = new gmds::blocking::CurvedBlocking(*m_blocking); + MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,new_b,hist_update); + if(m->m_typeMove == 2){ + //TODO ERROR, sometimes, block select not in the current blocks list...Check why !!! + std::cout<<"LIST BLOCK BLOCKING : "<get_all_id_blocks()){ + std::cout<m_AIdBlock){ + b_in_list = true; + break; + } + } + if(b_in_list){ + std::cout<<"BLOCK A DELETE :"<m_AIdBlock<m_blocking->remove_block(m->m_AIdBlock); + } + else{ + std::cout<<"BLOCK A DELETE :"<get_all_id_blocks().front()<m_blocking->remove_block(m_blocking->get_all_id_blocks().front()); + } + + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); + return new_state; + } + else if(m->m_typeMove ==1) { + new_state->m_blocking->cut_sheet(m->m_AIdEdge,m->m_AParamCut); + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); + return new_state; + } + else{ + std::cerr << "Warning: Bad type move ! \n Type move :" << m->m_typeMove << " & ID " << m->m_AIdEdge<< std::endl; + return new_state; + } + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "M1_action"+ id_act +".vtk"; + m_blocking->save_vtk_blocking(name_save_folder+name_file); + +} +/*----------------------------------------------------------------------------*/ +double +MCTSStatePolycube::state_rollout() const +{ + std::cout<<"STATE ROLLOUT"< *list_action = actions_to_try(); + //Get first move/action + //But, maybe, better to take rand move if its a delete move... + MCTSMove *firstMove = list_action->front(); //TODO: implement random move when only delete moves is possible + list_action->pop_front(); + + MCTSStatePolycube *old = curstate; + std::cout<<"===== SIZE UNTRIED ACTIONS : "<size()+1<<" ====="<next_state(firstMove); + if (!first) { + delete old; + } + first = false; + } while (!curstate->is_terminal()); + + if(curstate->result_terminal() == WIN){ + res=1; + } + else if (curstate->result_terminal() == LOSE) { + res=-1; + } + else{ + //Draw + res=0; + } + delete curstate; + return res; +} +/*----------------------------------------------------------------------------*/ +MCTSStatePolycube::ROLLOUT_STATUS +MCTSStatePolycube::result_terminal() const +{ + if (m_class_errors.non_captured_points.empty() && m_class_errors.non_captured_curves.empty() && m_class_errors.non_captured_surfaces.empty()) { + return WIN; + } + else if (check_nb_same_quality() >= 3){ + return DRAW; + } + else if (!m_history.empty() && m_history.back() < this->get_quality()){ + return LOSE; + } + else if (this->actions_to_try()->empty()){ + return LOSE; + } + std::cerr << "ERROR: NOT terminal state ..." << std::endl; + return DRAW; +} +/*----------------------------------------------------------------------------*/ +int MCTSStatePolycube::check_nb_same_quality() const +{ + int nb_same = 0; + double state_quality = get_quality(); + for (int i = m_history.size() - 1; i >= 0; --i) { + if(m_history[i] == state_quality){ + nb_same++; + } + else{ + break; + } + } + return nb_same; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSStatePolycube::is_terminal() const +{ + if (m_class_errors.non_captured_points.empty() && m_class_errors.non_captured_curves.empty() && m_class_errors.non_captured_surfaces.empty()) { + return true; + } + else if(check_nb_same_quality() >= 3){ + return true; + } + else if(!m_history.empty() && m_history.back() < this->get_quality()){ + return true; + } + else if(this->actions_to_try()->empty()){ + return true; + } + else { + std::cout<<"NOT TERMINAL STATE"<geom_model(); +} + +/*----------------------------------------------------------------------------*/ +gmds::blocking::CurvedBlocking *MCTSStatePolycube::get_blocking() +{ + return m_blocking; +} + +/*----------------------------------------------------------------------------*/ +gmds::blocking::CurvedBlockingClassifier *MCTSStatePolycube::get_class() +{ + return m_class_blocking; +} +/*----------------------------------------------------------------------------*/ +gmds::blocking::ClassificationErrors MCTSStatePolycube::get_errors() +{ + return m_class_errors; +} +/*----------------------------------------------------------------------------*/ +std::vector MCTSStatePolycube::get_history() const +{ + return m_history; +} +/*----------------------------------------------------------------------------*/ +void MCTSStatePolycube::update_class() +{ + gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); + m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); + m_class_errors = m_class_blocking->classify(0.2); +} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSTree.cpp b/rlBlocking/src/MCTSTree.cpp new file mode 100644 index 000000000..24044c6a4 --- /dev/null +++ b/rlBlocking/src/MCTSTree.cpp @@ -0,0 +1,287 @@ +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +/*-------------------------------- MCTS NODE ---------------------------------*/ +/*----------------------------------------------------------------------------*/ +MCTSNode::MCTSNode(gmds::MCTSNode *AParent, const gmds::MCTSMove *AMove, MCTSState *AState) + :parent(AParent), move(AMove),state(AState), score(0.0), number_of_simulations(0), size(0) +{ + children = new std::vector(); + untried_actions = state->actions_to_try(); + terminal = state->is_terminal(); + + + +} +/*----------------------------------------------------------------------------*/ +MCTSNode::~MCTSNode() { + delete state; + delete move; + for (auto *child : *children) { + delete child; + } + delete children; + while (!untried_actions->empty()) { + delete untried_actions->front(); // if a move is here then it is not a part of a child node and needs to be deleted here + untried_actions->pop_front(); + } + delete untried_actions; +} + +/*----------------------------------------------------------------------------*/ +void MCTSNode::expand() { + if (is_terminal()) { // can legitimately happen in end-game situations + rollout(); // keep rolling out, eventually causing UCT to pick another node to expand due to exploration + return; + } else if (is_fully_expanded()) { + std::cerr << "Warning: Cannot expanded this node any more!" << std::endl; + return; + } + // get next untried action + MCTSMove *next_move = untried_actions->front(); // get value + untried_actions->pop_front(); // remove it + MCTSState *next_state = state->next_state(next_move); + + if(state->get_quality() == next_state->get_quality()){ + const unsigned int nb_same_quality = this->nb_same_quality + 1; + } + else{ + const unsigned int nb_same_quality = 0; + } + // build a new MCTS node from it + MCTSNode *new_node = new MCTSNode(this,next_move,next_state); + // rollout, updating its stats + new_node->rollout(); + // add new node to tree + children->push_back(new_node); +} +/*----------------------------------------------------------------------------*/ +const MCTSState *MCTSNode::get_current_state() const +{ + return state; +} +/*----------------------------------------------------------------------------*/ +std::vector +*MCTSNode::get_children() +{ + return children; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSNode::is_terminal() const +{ + return terminal; +} + +/*----------------------------------------------------------------------------*/ +const MCTSMove *MCTSNode::get_move() const { + return move; +} +/*----------------------------------------------------------------------------*/ +bool +MCTSNode::is_fully_expanded() const +{ + return is_terminal() || untried_actions->empty(); +} +/*----------------------------------------------------------------------------*/ +unsigned int MCTSNode::get_size() const { + return size; +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSNode::select_best_child(double c) const { + /** selects best child based on the winrate of whose turn it is to play */ + if (children->empty()) { + return NULL; + } + else if (children->size() == 1) return children->at(0); + else { + double uct, max = -1; + MCTSNode *argmax = NULL; + for (auto *child : *children) { + double winrate = child->score / ((double) child->number_of_simulations); + if (c > 0) { + uct = winrate + + c * sqrt(log((double) this->number_of_simulations) / ((double) child->number_of_simulations)); + } else { + uct = winrate; + } + if (uct > max) { + max = uct; + argmax = child; + } + } + return argmax; + } +} +/*----------------------------------------------------------------------------*/ +void +MCTSNode::rollout() +{ + double w = state->state_rollout(); + backpropagate(w, 1); +} +/*----------------------------------------------------------------------------*/ +void MCTSNode::backpropagate(double w, int n) { + score += w; + number_of_simulations += n; + if (parent != NULL) { + parent->size++; + parent->backpropagate(w, n); + } +} + +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSNode::advance_tree(const MCTSMove *m) { + //TODO + // Find child with this m and delete all others + MCTSNode *next = NULL; + for (auto *child: *children) { + if (*(child->move) == *(m)) { + next = child; + } else { + delete child; + } + } + // remove children from queue so that they won't be re-deleted by the destructor when this node dies (!) + this->children->clear(); + // if not found then we have to create a new node + if (next == NULL) { + // Note: UCT may lead to not fully explored tree even for short-term children due to terminal nodes being chosen + std::cout << "INFO: Didn't find child node. Had to start over." << std::endl; + MCTSState *next_state = state->next_state(m); + next = new MCTSNode(this, m,next_state); + } else { + next->parent = NULL; // make parent NULL + // IMPORTANT: m and next->move can be the same here if we pass the move from select_best_child() + // (which is what we will typically be doing). If not then it's the caller's responsibility to delete m (!) + } + // return the next root + return next; +} +/*----------------------------------------------------------------------------*/ +void MCTSNode::print_stats() const { +#define TOPK 10 + if (number_of_simulations == 0) { + std::cout << "Tree not expanded yet" << std::endl; + return; + } + std::cout << "___ INFO _______________________" << std::endl + << "Tree size: " << size << std::endl + << "Number of simulations: " << number_of_simulations << std::endl + << "Branching factor at root: " << children->size() << std::endl; + // Print the best move for a current node +// MCTSNode *bestChild; +// bool first = true; +// double winRateChild = 0; +// if(!children->empty()) { +// for (int i = 0; i < children->size(); i++) { +// if (first) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// first = false; +// } +// +// else if (winRateChild < children->at(i)->calculate_winrate()) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// } +// } +// std::cout << "Best Move :" << std::endl; +// bestChild->move->print(); +// } + std::cout << "________________________________" << std::endl; +} + + +/*----------------------------------------------------------------------------*/ +double MCTSNode::calculate_winrate() const { + return score / number_of_simulations; +} + +/*----------------------------------------------------------------------------*/ +/*-------------------------------- MCTS TREE --------------------------------*/ +/*----------------------------------------------------------------------------*/ +MCTSTree::MCTSTree(MCTSState* starting_state) +{ + assert(starting_state != NULL); + root = new MCTSNode(NULL, NULL, starting_state); +} +/*----------------------------------------------------------------------------*/ +MCTSTree::~MCTSTree() +{ + delete root; +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSTree::select(double c) { + MCTSNode *node = root; + while (!node->is_terminal()) { + if (!node->is_fully_expanded()) { + return node; + } else { + node = node->select_best_child(c); + } + } + return node; +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::grow_tree(int max_iter, double max_time_in_seconds) { + MCTSNode *node; + double dt; +#ifdef DEBUG + std::cout << "Growing tree..." << std::endl; +#endif + time_t start_t, now_t; + time(&start_t); + for (int i = 0 ; i < max_iter ; i++){ + // select node to expand according to tree policy + node = select(); + // expand it (this will perform a rollout and backpropagate the results) + node->expand(); + // check if we need to stop + time(&now_t); + dt = difftime(now_t, start_t); + if (dt > max_time_in_seconds) { +#ifdef DEBUG + std::cout << "Early stopping: Made " << (i + 1) << " iterations in " << dt << " seconds." << std::endl; +#endif + break; + } + } +#ifdef DEBUG + time(&now_t); + dt = difftime(now_t, start_t); + cout << "Finished in " << dt << " seconds." << endl; +#endif +} +/*----------------------------------------------------------------------------*/ +MCTSNode *MCTSTree::select_best_child() { + return root->select_best_child(0.0); +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::advance_tree(const MCTSMove *move) { + MCTSNode *old_root = root; + root = root->advance_tree(move); + delete old_root; // this won't delete the new root since we have emptied old_root's children +} + +/*----------------------------------------------------------------------------*/ +unsigned int MCTSTree::get_size() const { + return root->get_size(); +} +/*----------------------------------------------------------------------------*/ +const MCTSState *MCTSTree::get_current_state() const +{ + return root->get_current_state(); +} +/*----------------------------------------------------------------------------*/ +void MCTSTree::print_stats() const +{ + root->print_stats(); +} +/*----------------------------------------------------------------------------*/ + diff --git a/rlBlocking/src/main_rlBlocking.cpp b/rlBlocking/src/main_rlBlocking.cpp index 5cfce3aca..789c7546a 100644 --- a/rlBlocking/src/main_rlBlocking.cpp +++ b/rlBlocking/src/main_rlBlocking.cpp @@ -44,7 +44,7 @@ int main(int argc, char* argv[]) gmds::Mesh vol_mesh(gmds::MeshModel(gmds::DIM3 | gmds::R | gmds::F | gmds::E | gmds::N | gmds::R2N | gmds::R2F | gmds::R2E | gmds::F2N | gmds::F2R | gmds::F2E | gmds::E2F | gmds::E2N | gmds::N2E)); - std::string vtk_file = "/home/bourmaudp/Documents/mambo-master/Basic/vtk/B0.vtk"; + std::string vtk_file = "/home/bourmaudp/Documents/mambo-master/Basic/vtk/cb3.vtk"; gmds::IGMeshIOService ioServiceA(&vol_mesh); gmds::VTKReader vtkReaderA(&ioServiceA); vtkReaderA.setCellOptions(gmds::N | gmds::R); diff --git a/rlBlocking/tst/BlockQualityTestSuite.h b/rlBlocking/tst/BlockQualityTestSuite.h index 6ec4c98aa..94e7e3b21 100644 --- a/rlBlocking/tst/BlockQualityTestSuite.h +++ b/rlBlocking/tst/BlockQualityTestSuite.h @@ -1,8 +1,6 @@ #ifndef GMDS_BLOCKQUALITYTESTSUITE_H #define GMDS_BLOCKQUALITYTESTSUITE_H -#endif // GMDS_BLOCKQUALITYTESTSUITE_H - // // Created by ledouxf on 1/22/19. @@ -102,3 +100,5 @@ TEST(BlockQualityTestSuite, test_Rubiks) ASSERT_EQ(2, 2);//linker.getGeomId(n1)); } + +#endif // GMDS_BLOCKQUALITYTESTSUITE_H diff --git a/rlBlocking/tst/CMakeLists.txt b/rlBlocking/tst/CMakeLists.txt index 67aaac829..1f8376e84 100644 --- a/rlBlocking/tst/CMakeLists.txt +++ b/rlBlocking/tst/CMakeLists.txt @@ -1,5 +1,6 @@ add_executable(GMDS_BLOCKINGQUALITY_TEST BlockQualityTestSuite.h + MCTSTestSuite.h main_test.cpp) target_link_libraries(GMDS_BLOCKINGQUALITY_TEST PUBLIC diff --git a/rlBlocking/tst/MCTSTestSuite.h b/rlBlocking/tst/MCTSTestSuite.h new file mode 100644 index 000000000..3d40cb22d --- /dev/null +++ b/rlBlocking/tst/MCTSTestSuite.h @@ -0,0 +1,57 @@ +#ifndef GMDS_MCTSTESTSUITE_H +#define GMDS_MCTSTESTSUITE_H +/*----------------------------------------------------------------------------*/ +#include +/*----------------------------------------------------------------------------*/ +#include +#include +/*----------------------------------------------------------------------------*/ +using namespace gmds; +/*----------------------------------------------------------------------------*/ +/**@brief setup function that initialize a geometric model using the faceted + * representation and an input vtk file name. The vtk file must contain a + * tetrahedral mesh + * + * @param AGeomModel geometric model we initialize + * @param AFileName vtk filename + */ +void set_up_MCTS(gmds::cad::FACManager* AGeomModel, const std::string AFileName) +{ + gmds::Mesh vol_mesh(gmds::MeshModel(gmds::DIM3 | gmds::R | gmds::F | gmds::E | gmds::N | gmds::R2N | gmds::R2F | gmds::R2E | gmds::F2N | gmds::F2R | gmds::F2E + | gmds::E2F | gmds::E2N | gmds::N2E)); + std::string dir(TEST_SAMPLES_DIR); + std::string vtk_file = dir +"/"+ AFileName; + gmds::IGMeshIOService ioService(&vol_mesh); + gmds::VTKReader vtkReader(&ioService); + vtkReader.setCellOptions(gmds::N | gmds::R); + vtkReader.read(vtk_file); + gmds::MeshDoctor doc(&vol_mesh); + doc.buildFacesAndR2F(); + doc.buildEdgesAndX2E(); + doc.updateUpwardConnectivity(); + AGeomModel->initFrom3DMesh(&vol_mesh); + +} +/*----------------------------------------------------------------------------*/ +TEST(MCTSTestSuite, testExAglo) +{ + + gmds::cad::FACManager geom_model; + set_up_MCTS(&geom_model,"cb2.vtk"); + gmds::blocking::CurvedBlocking bl(&geom_model,true); + bl.save_vtk_blocking("/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/cb2_init_blocking.vtk"); + + gmds::blocking::CurvedBlockingClassifier classifier(&bl); + std::cout<<"==================== BEGIN TEST : ===================="<execute(); + + + std::cout<<"==================== END TEST ! ===================="<